Adding Custom Modules#

Overview#

FluxVLA employs a Registry mechanism to manage various module types, enabling flexible addition of custom implementations. The extensible module types include:

| Registry | Description | Code Location |
| --- | --- | --- |
| HEADS | Action prediction heads (e.g., Diffusion Policy, VAE) | fluxvla/models/heads/ |
| VLAS | VLA models | fluxvla/models/vlas/ |
| LLM_BACKBONES | Language model backbones | fluxvla/models/backbones/llms/ |
| VISION_BACKBONES | Vision backbones | fluxvla/models/backbones/visions/ |
| VLM_BACKBONES | Vision-language model backbones | fluxvla/models/backbones/vlms/ |
| PROJECTORS | Feature projectors | fluxvla/models/projectors/ |
| DATASETS | Datasets | fluxvla/datasets/ |
| TRANSFORMS | Data transforms | fluxvla/transforms/ |
| COLLATORS | Data collators | fluxvla/collators/ |
| TOKENIZERS | Tokenizers | fluxvla/tokenizers/ |
| RUNNERS | Training/inference runners | fluxvla/engines/runners/ |
| OPERATORS | Robot operators | fluxvla/engines/operators/ |
| METRICS | Evaluation metrics | fluxvla/engines/metrics/ |
| PROCESSORS | Processors | fluxvla/engines/processors/ |

All modules are managed through their corresponding registries and can be referenced in configuration files via the type field. The procedure for adding any module type is consistent: create a file → register the module → import → reference in configuration.
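
Conceptually, a registry is little more than a name-to-class table plus config-driven construction. The miniature below is an illustration of that pattern, not FluxVLA's actual implementation; it shows why the register_module() decorator and the type field in a config dict are two halves of the same mechanism:

# Miniature registry (illustration only, not FluxVLA's actual code).
class Registry:
    def __init__(self):
        self._modules = {}

    def register_module(self):
        def wrap(cls):
            self._modules[cls.__name__] = cls  # record the class under its own name
            return cls
        return wrap

    def build(self, cfg: dict):
        cfg = dict(cfg)                        # don't mutate the caller's config
        cls = self._modules[cfg.pop('type')]   # 'type' selects the class...
        return cls(**cfg)                      # ...remaining keys go to __init__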

This tutorial uses a VLA Head (action prediction head) as an example to demonstrate how to add a custom module.

The Head-related code is located in fluxvla/models/heads/:

fluxvla/models/heads/
├── flow_matching_head.py            # FlowMatchingHead (DiT)
├── flow_matching_inference_head.py  # FlowMatchingInferenceHead (CUDA Graph)
├── llava_action_head.py             # LlavaActionHead (Transformer)
└── openvla_head.py                  # OpenVLAHead (Token-based)

Design references for existing Heads:

| Head Type | Action Prediction Method | Characteristics |
| --- | --- | --- |
| FlowMatchingHead | Flow Matching (DiT) | Generative, multi-step denoising; suited to complex action distributions |
| LlavaActionHead | Transformer decoding | Autoregressive prediction; supports variable-length trajectories |
| OpenVLAHead | Token discretization | Token prediction via the VLM; no additional network required |


Adding a New VLA Head#

Step 1: Create the Head File#

Create a new file under fluxvla/models/heads/, for example my_action_head.py:

from typing import Callable, Dict, Optional

import torch
import torch.nn as nn

from fluxvla.engines.utils.root import HEADS


@HEADS.register_module()
class MyActionHead(nn.Module):
    """自定义动作预测头。"""

    def __init__(self,
                 hidden_size: int = 1024,
                 state_dim: int = 64,
                 action_dim: int = 32,
                 input_embedding_dim: int = 1536,
                 traj_length: int = 10,
                 ori_action_dim: Optional[int] = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.traj_length = traj_length
        self.ori_action_dim = ori_action_dim or action_dim

        # State encoder: map the robot state into the latent space
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size))

        # Action decoder: predict the action sequence from the features
        self.action_decoder = nn.Sequential(
            nn.Linear(input_embedding_dim + hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, traj_length * action_dim))

    def forward(self,
                input_features: torch.Tensor,
                states: torch.Tensor,
                attention_mask: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """前向传播(训练时调用)。

        Args:
            input_features: VLM 输出特征, shape=(B, seq_len, embed_dim)
            states: 机器人状态, shape=(B, state_dim)
            attention_mask: 注意力掩码, shape=(B, seq_len)
            actions: 目标动作, shape=(B, traj_length, action_dim)
            action_masks: 动作掩码, shape=(B, traj_length)

        Returns:
            dict: {'pred_actions': ..., 'loss': ...}
        """
        B = input_features.shape[0]

        # 1. Encode the robot state
        state_emb = self.state_encoder(states)  # (B, hidden_size)

        # 2. Pool the VLM features (take the last valid token)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        batch_idx = torch.arange(B, device=input_features.device)
        pooled = input_features[batch_idx, seq_lengths]  # (B, embed_dim)

        # 3. Concatenate and decode into the action sequence
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)  # (B, traj_len * action_dim)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)

        # 4. Masked MSE loss over valid action steps
        loss = nn.functional.mse_loss(
            pred_actions[action_masks.bool()],
            actions[action_masks.bool()])

        return {'pred_actions': pred_actions, 'loss': loss}

    def predict_action(self,
                       input_features: torch.Tensor,
                       states: torch.Tensor,
                       attention_mask: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """动作预测(推理时调用)。

        Returns:
            torch.Tensor: 预测动作序列, shape=(B, traj_length, action_dim)
        """
        B = input_features.shape[0]
        state_emb = self.state_encoder(states)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        batch_idx = torch.arange(B, device=input_features.device)
        pooled = input_features[batch_idx, seq_lengths]
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
        return pred_actions

    def get_fsdp_wrapping_policy(self) -> Optional[Callable]:
        """Return the FSDP wrapping policy."""
        return None  # a simple model needs no special FSDP sharding policy

Step 2: Understand the Head Interface Contract#

A VLA Head must implement two core methods:

| Method | Scenario | Input | Output |
| --- | --- | --- | --- |
| forward(...) | Training | Features + states + target actions | {'pred_actions': Tensor, 'loss': Tensor} |
| predict_action(...) | Inference | Features + states | Tensor (predicted action sequence) |

forward() input parameters:

| Parameter | Type | Description |
| --- | --- | --- |
| input_features | torch.Tensor | VLM output features, shape=(B, seq_len, embed_dim) |
| states | torch.Tensor | Robot states, shape=(B, state_dim) |
| attention_mask | torch.Tensor | Attention mask, shape=(B, seq_len) |
| actions | torch.Tensor | Target actions, shape=(B, traj_length, action_dim) |
| action_masks | torch.Tensor | Action masks, shape=(B, traj_length) |

forward() return value: Must return a dictionary containing a 'loss' key, e.g., {'pred_actions': Tensor, 'loss': Tensor}.

predict_action() return value: Returns the predicted action tensor with shape=(B, traj_length, action_dim).
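
A quick way to check the contract is a shape-level smoke test. The snippet below is a sketch that instantiates MyActionHead with the default constructor arguments from Step 1 and feeds it random tensors; it verifies that forward() returns a dict with a scalar 'loss' and that predict_action() returns the expected action shape:

import torch

from fluxvla.models.heads.my_action_head import MyActionHead

head = MyActionHead()  # defaults: input_embedding_dim=1536, state_dim=64, ...
B, seq_len = 2, 16

input_features = torch.randn(B, seq_len, 1536)
states = torch.randn(B, 64)
attention_mask = torch.ones(B, seq_len)
actions = torch.randn(B, 10, 32)
action_masks = torch.ones(B, 10)

out = head(input_features, states, attention_mask, actions, action_masks)
assert out['loss'].dim() == 0                    # scalar training loss
assert out['pred_actions'].shape == (B, 10, 32)  # (B, traj_length, action_dim)

pred = head.predict_action(input_features, states, attention_mask)
assert pred.shape == (B, 10, 32)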

Step 3: Register and Import#

Add the import statement in fluxvla/models/heads/__init__.py:

from .flow_matching_head import FlowMatchingHead  # noqa: F401, F403
from .flow_matching_inference_head import FlowMatchingInferenceHead  # noqa: F401, F403
from .llava_action_head import LlavaActionHead  # noqa: F401, F403
from .my_action_head import MyActionHead  # noqa: F401, F403   ← new
from .openvla_head import OpenVLAHead  # noqa: F401, F403
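
To confirm the registration took effect, look the class up by name after importing the package. The check below assumes FluxVLA's registry exposes an mmengine-style get(); adapt the lookup if the actual API differs:

import fluxvla.models.heads  # noqa: F401  (runs __init__.py, triggering registration)

from fluxvla.engines.utils.root import HEADS

# Assumes an mmengine-style get(); a failure here usually means the
# import line in __init__.py is missing.
assert HEADS.get('MyActionHead') is not None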

Step 4: Use in Configuration File#

Simply modify the type and corresponding parameters in the model.vla_head dictionary:

model = dict(
    type='LlavaVLA',                    # VLA unchanged
    vlm_backbone=dict(...),             # backbone unchanged
    vla_head=dict(
        type='MyActionHead',            # ← swap in your Head
        hidden_size=1024,
        state_dim=64,
        action_dim=32,
        input_embedding_dim=1536,
        traj_length=10,
        ori_action_dim=14),
    ...)

Core Checklist#

After adding a new VLA Head, verify each item in the following checklist:

  • The Head file has been created under the fluxvla/models/heads/ directory

  • The module is registered using the @HEADS.register_module() decorator

  • An import statement has been added in fluxvla/models/heads/__init__.py

  • The forward() method is implemented and returns a dictionary containing a 'loss' key

  • The predict_action() method is implemented and returns an action tensor

  • The type in the configuration file matches the registered class name exactly

  • The parameter names in the configuration file match the parameter names of the __init__ method


Frequently Asked Questions#

Q: What configuration changes are required after adding a new Head?#

Only the type and corresponding parameters in the model.vla_head dictionary need to be modified. The VLA model and backbone configurations do not require any changes.

Q: What is inference_model?#

inference_model is an optional inference-specific model configuration. It is used when a different model setup is required during inference (e.g., using FlowMatchingInferenceHead instead of FlowMatchingHead to enable CUDA Graph acceleration):

model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingHead', ...), ...)
inference_model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingInferenceHead', ...), ...)

If inference_model is not defined, the model configuration is used by default during inference.
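
In other words, the fallback behaves like the following one-liner (a sketch of the assumed selection logic, not FluxVLA's actual code):

# cfg is the parsed configuration; 'inference_model' falls back to 'model'.
inference_cfg = cfg.get('inference_model', cfg['model'])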

Q: How can I ensure my Head is compatible with FSDP training?#

If your Head contains a large number of parameters, you need to implement the get_fsdp_wrapping_policy() method to specify which submodules should be individually sharded by FSDP:

def get_fsdp_wrapping_policy(self):
    from functools import partial
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    return partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={MyTransformerBlock})

If the Head has a small number of parameters, returning None is sufficient.