Adding Custom Modules#
Overview#
FluxVLA employs a Registry mechanism to manage various module types, enabling flexible addition of custom implementations. The extensible module types include:
- Action prediction heads (e.g., Diffusion Policy, VAE, etc.)
- VLA models
- Language model backbones
- Vision backbones
- Vision-language model backbones
- Feature projectors
- Datasets
- Data transforms
- Data collators
- Tokenizers
- Training/inference runners
- Robot operators
- Evaluation metrics
- Processors
All modules are managed through their corresponding registries and can be referenced in configuration files via the type field. The procedure for adding any module type is consistent: create a file → register the module → import → reference in configuration.
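As a rough sketch of this pattern (assuming an mmengine-style Registry with a build() method; the actual API defined in fluxvla/engines/utils/root.py may differ), the type field in a config dict selects the registered class, and the remaining keys become its constructor arguments:

    from fluxvla.engines.utils.root import HEADS

    # Hypothetical usage: 'type' names the registered class; the remaining
    # keys are forwarded to its __init__ as keyword arguments.
    head_cfg = dict(type='FlowMatchingHead', hidden_size=1024)
    head = HEADS.build(head_cfg)  # assumed mmengine-style build()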
This tutorial uses a VLA Head (action prediction head) as an example to demonstrate how to add a custom module.
The Head-related code is located in fluxvla/models/heads/:
fluxvla/models/heads/
├── flow_matching_head.py # FlowMatchingHead (DiT)
├── flow_matching_inference_head.py # FlowMatchingInferenceHead (CUDA Graph)
├── llava_action_head.py # LlavaActionHead (Transformer)
└── openvla_head.py # OpenVLAHead (Token-based)
Design references for existing Heads:

| Head Type | Action Prediction Method | Characteristics |
|---|---|---|
| FlowMatchingHead | Flow Matching (DiT) | Generative, multi-step denoising, suitable for complex action distributions |
| LlavaActionHead | Transformer decoding | Autoregressive prediction, supports variable-length trajectories |
| OpenVLAHead | Token discretization | Token prediction based on the VLM, no additional network required |
Adding a New VLA Head#
Step 1: Create the Head File#
Create a new file under fluxvla/models/heads/, for example my_action_head.py:
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn

from fluxvla.engines.utils.root import HEADS


@HEADS.register_module()
class MyActionHead(nn.Module):
    """Custom action prediction head."""

    def __init__(self,
                 hidden_size: int = 1024,
                 state_dim: int = 64,
                 action_dim: int = 32,
                 input_embedding_dim: int = 1536,
                 traj_length: int = 10,
                 ori_action_dim: Optional[int] = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.traj_length = traj_length
        self.ori_action_dim = ori_action_dim or action_dim

        # State encoder: maps the robot state into the hidden space
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size))

        # Action decoder: predicts the action sequence from features
        self.action_decoder = nn.Sequential(
            nn.Linear(input_embedding_dim + hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, traj_length * action_dim))

    def forward(self,
                input_features: torch.Tensor,
                states: torch.Tensor,
                attention_mask: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """Forward pass (called during training).

        Args:
            input_features: VLM output features, shape=(B, seq_len, embed_dim)
            states: Robot states, shape=(B, state_dim)
            attention_mask: Attention mask, shape=(B, seq_len)
            actions: Target actions, shape=(B, traj_length, action_dim)
            action_masks: Action masks, shape=(B, traj_length)

        Returns:
            dict: {'pred_actions': ..., 'loss': ...}
        """
        B = input_features.shape[0]

        # 1. Encode the robot state
        state_emb = self.state_encoder(states)  # (B, hidden_size)

        # 2. Pool the VLM features (take the last valid token)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        pooled = input_features[range(B), seq_lengths]  # (B, embed_dim)

        # 3. Concatenate and predict actions
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)  # (B, traj_len * action_dim)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)

        # 4. Compute the loss
        loss = nn.functional.mse_loss(
            pred_actions[action_masks.bool()],
            actions[action_masks.bool()])
        return {'pred_actions': pred_actions, 'loss': loss}

    def predict_action(self,
                       input_features: torch.Tensor,
                       states: torch.Tensor,
                       attention_mask: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """Action prediction (called during inference).

        Returns:
            torch.Tensor: Predicted action sequence, shape=(B, traj_length, action_dim)
        """
        B = input_features.shape[0]
        state_emb = self.state_encoder(states)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        pooled = input_features[range(B), seq_lengths]
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
        return pred_actions

    def get_fsdp_wrapping_policy(self) -> Optional[Callable]:
        """Return the FSDP wrapping policy."""
        return None  # A simple head needs no special FSDP sharding policy
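Before wiring the new Head into a full training run, a quick smoke test with random tensors helps catch shape mismatches early (a minimal sketch; shapes follow the docstrings above):

    import torch

    head = MyActionHead(hidden_size=1024, state_dim=64, action_dim=32,
                        input_embedding_dim=1536, traj_length=10)

    B, seq_len = 2, 77  # arbitrary batch size and sequence length
    out = head(
        input_features=torch.randn(B, seq_len, 1536),
        states=torch.randn(B, 64),
        attention_mask=torch.ones(B, seq_len),
        actions=torch.randn(B, 10, 32),
        action_masks=torch.ones(B, 10))
    assert out['pred_actions'].shape == (B, 10, 32)
    assert out['loss'].dim() == 0  # scalar MSE loss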
Step 2: Understand the Head Interface Contract#
A VLA Head must implement two core methods:
| Method | Scenario | Input | Output |
|---|---|---|---|
| forward() | Training | Features + states + target actions | Dict containing a 'loss' key |
| predict_action() | Inference | Features + states | Predicted action tensor |
forward() input parameters:
| Parameter | Type | Description |
|---|---|---|
| input_features | torch.Tensor | VLM output features, shape=(B, seq_len, embed_dim) |
| states | torch.Tensor | Robot states, shape=(B, state_dim) |
| attention_mask | torch.Tensor | Attention mask, shape=(B, seq_len) |
| actions | torch.Tensor | Target actions, shape=(B, traj_length, action_dim) |
| action_masks | torch.Tensor | Action masks, shape=(B, traj_length) |
forward() return value: Must return a dictionary containing a 'loss' key, e.g., {'pred_actions': Tensor, 'loss': Tensor}.
predict_action() return value: Returns the predicted action tensor with shape=(B, traj_length, action_dim).
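At inference time the runner calls predict_action() without target actions, typically under torch.no_grad() (a sketch reusing the head and shapes from the smoke test in Step 1):

    with torch.no_grad():
        pred = head.predict_action(
            input_features=torch.randn(B, seq_len, 1536),
            states=torch.randn(B, 64),
            attention_mask=torch.ones(B, seq_len))
    print(pred.shape)  # torch.Size([2, 10, 32])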
Step 3: Register and Import#
Add the import statement in fluxvla/models/heads/__init__.py:
from .flow_matching_head import FlowMatchingHead  # noqa: F401, F403
from .flow_matching_inference_head import FlowMatchingInferenceHead  # noqa: F401, F403
from .llava_action_head import LlavaActionHead  # noqa: F401, F403
from .my_action_head import MyActionHead  # noqa: F401, F403  ← newly added
from .openvla_head import OpenVLAHead  # noqa: F401, F403
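To confirm that the import actually registers the class, you can try resolving it through the registry (assuming an mmengine-style Registry.get() accessor; the exact API in FluxVLA may differ):

    import fluxvla.models.heads  # noqa: F401 — triggers the __init__.py imports
    from fluxvla.engines.utils.root import HEADS

    assert HEADS.get('MyActionHead') is not None  # assumed accessor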
Step 4: Use in Configuration File#
Simply modify the type and corresponding parameters in the model.vla_head dictionary:
model = dict(
    type='LlavaVLA',            # VLA model unchanged
    vlm_backbone=dict(...),     # backbone unchanged
    vla_head=dict(
        type='MyActionHead',    # ← swap in your Head
        hidden_size=1024,
        state_dim=64,
        action_dim=32,
        input_embedding_dim=1536,
        traj_length=10,
        ori_action_dim=14),
    ...)
Core Checklist#
After adding a new VLA Head, verify each item in the following checklist:
- The Head file has been created under the fluxvla/models/heads/ directory
- The module is registered using the @HEADS.register_module() decorator
- An import statement has been added in fluxvla/models/heads/__init__.py
- The forward() method is implemented and returns a dictionary containing a 'loss' key
- The predict_action() method is implemented and returns an action tensor
- The type in the configuration file matches the registered class name exactly
- The parameter names in the configuration file match the parameter names of the __init__ method
Frequently Asked Questions#
Q: What configuration changes are required after adding a new Head?#
Only the type and corresponding parameters in the model.vla_head dictionary need to be modified. The VLA model and backbone configurations do not require any changes.
Q: What is inference_model?#
inference_model is an optional inference-specific model configuration. It is used when a different model setup is required during inference (e.g., using FlowMatchingInferenceHead instead of FlowMatchingHead to enable CUDA Graph acceleration):
model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingHead', ...), ...)
inference_model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingInferenceHead', ...), ...)
If inference_model is not defined, the model configuration is used by default during inference.
Q: How can I ensure my Head is compatible with FSDP training?#
If your Head contains a large number of parameters, you need to implement the get_fsdp_wrapping_policy() method to specify which submodules should be individually sharded by FSDP:
def get_fsdp_wrapping_policy(self):
    from functools import partial
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    # MyTransformerBlock stands in for your Head's transformer layer class
    return partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={MyTransformerBlock})
If the Head has a small number of parameters, returning None is sufficient.
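For context, the returned policy is what standard PyTorch FSDP consumes as its auto_wrap_policy argument. A sketch of plain PyTorch usage (not FluxVLA's runner code; it requires an initialized distributed process group):

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    head = MyActionHead()
    wrapped = FSDP(head, auto_wrap_policy=head.get_fsdp_wrapping_policy())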