Adding Custom Models#

Overview#

In FluxVLA, a complete VLA model (Vision-Language-Action Model) is composed of multiple submodules:

| Submodule | Registry | Purpose | Examples |
|---|---|---|---|
| VLA | VLAS | Top-level model that orchestrates all submodules | LlavaVLA, OpenVLA, PI0FlowMatching |
| VLM Backbone | VLM_BACKBONES | Vision-language multimodal understanding | EagleBackbone, PaliGemma, QWen2_5VL |
| Vision Backbone | VISION_BACKBONES | Pure visual feature extraction | SigLIPViTBackbone, DinoSigLIPViTBackbone |
| LLM Backbone | LLM_BACKBONES | Pure language model | LLaMa2LLMBackbone, GemmaLLMBackbone |
| Projector | PROJECTORS | Feature space mapping (vision → language) | MLPProjector, LinearProjector, FusedMLPProjector |
| VLA Head | HEADS | Action prediction head | FlowMatchingHead, LlavaActionHead, OpenVLAHead |

This tutorial provides guidance on how to add custom model components to FluxVLA — whether it is an entirely new VLA architecture, a new VLM backbone, or a new action prediction head.


Architecture Overview#

Model Inheritance Hierarchy#

BaseVLA (ABC)
├── OpenVLA                  # Token-based action prediction
│   └── LlavaVLA             # Continuous action prediction with a VLM backbone
└── PI0FlowMatching          # Flow Matching generative action prediction
    └── PI05FlowMatching     # Improved PI0.5 variant

Directory Structure#

All model-related code is located under fluxvla/models/:

fluxvla/models/
├── __init__.py              # Exports backbones, heads, projectors, vlas
├── vlas/
│   ├── __init__.py          # Imports all VLAs (triggers registration)
│   ├── base_vla.py          # Abstract VLA base class
│   ├── open_vla.py          # OpenVLA
│   ├── llava_vla.py         # LlavaVLA (inherits from OpenVLA)
│   ├── pi0_flowmatching.py  # PI0FlowMatching
│   └── pi05_flowmatching.py # PI05FlowMatching (inherits from PI0)
├── backbones/
│   ├── vlms/                # VLM backbones
│   │   ├── eagle.py         # EagleBackbone
│   │   ├── paligemma.py     # PaliGemma
│   │   └── qwen2_5_vl.py    # QWen2_5VL
│   ├── visions/             # Vision backbones
│   │   ├── siglip_vit.py    # SigLIPViTBackbone
│   │   └── dinosiglip_vit.py # DinoSigLIPViTBackbone
│   └── llms/                # LLM backbones
│       ├── llama2.py        # LLaMa2LLMBackbone
│       ├── gemma.py         # GemmaLLMBackbone
│       └── qwen2.py         # Qwen2LLMBackbone
├── heads/
│   ├── flow_matching_head.py           # FlowMatchingHead (DiT)
│   ├── flow_matching_inference_head.py # FlowMatchingInferenceHead (CUDA Graph)
│   ├── llava_action_head.py            # LlavaActionHead (Transformer)
│   └── openvla_head.py                 # OpenVLAHead (Token-based)
├── projectors/
│   ├── mlp_projector.py     # MLPProjector
│   ├── linear_projector.py  # LinearProjector
│   └── fused_projector.py   # FusedMLPProjector
└── blocks/
    └── cross_attention_dit.py # DiT Transformer block

Registration and Construction Mechanism#

FluxVLA employs a unified Registry pattern to manage all model components. Each component is registered via a decorator, and configuration files reference components through the type field.

Registration:

from fluxvla.engines.utils.root import VLAS

@VLAS.register_module()
class MyVLA(BaseVLA):
    ...

Construction (automatically performed within BaseVLA.__init__):

# Submodule construction logic inside BaseVLA.__init__
if vlm_backbone is not None:
    self.vlm_backbone = build_vlm_backbone_from_cfg(vlm_backbone)
if vision_backbone is not None:
    self.vision_backbone = build_vision_backbone_from_cfg(vision_backbone)
if llm_backbone is not None:
    self.llm_backbone = build_llm_backbone_from_cfg(llm_backbone)
if projector is not None:
    self.projector = build_projector_from_cfg(projector)
if vla_head is not None:
    self.vla_head = build_head_from_cfg(vla_head)

Model construction during training (BaseTrainRunner.__init__):

self.vla = build_vla_from_cfg(cfg.model)   # reads the `model` dict from the config

Therefore, you only need to specify the type and the corresponding parameters in the configuration file; the framework will automatically look up and instantiate all submodules through the registry.
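For intuition, each build_*_from_cfg helper conceptually performs a registry lookup followed by keyword instantiation. The sketch below is illustrative only; it assumes a `get`-style lookup, and FluxVLA's actual Registry implementation may differ:

def build_from_cfg(cfg: dict, registry) -> object:
    """Illustrative sketch of registry-based construction (not the real FluxVLA code)."""
    cfg = dict(cfg)                      # copy so the caller's config is not mutated
    cls = registry.get(cfg.pop('type'))  # look up the registered class by its name
    return cls(**cfg)                    # remaining keys become __init__ kwargs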


Adding a New VLA Model#

When you need to implement an entirely new VLA architecture (e.g., a different training paradigm or a different modality fusion strategy), you should add a new top-level VLA model.

Step 1: Create the Model File#

Create a new file under fluxvla/models/vlas/, for example my_vla.py:

from typing import Dict, Optional

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from fluxvla.engines.utils.root import VLAS
from .base_vla import BaseVLA


@VLAS.register_module()
class MyVLA(BaseVLA):
    """自定义 VLA 模型示例。"""

    def __init__(self,
                 my_custom_param: int = 256,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.my_custom_param = my_custom_param
        # BaseVLA has already built the following from the config:
        #   self.vlm_backbone    (if vlm_backbone is specified in the config)
        #   self.vision_backbone (if vision_backbone is specified in the config)
        #   self.llm_backbone    (if llm_backbone is specified in the config)
        #   self.projector       (if projector is specified in the config)
        #   self.vla_head        (if vla_head is specified in the config)

    def forward(self,
                images: torch.Tensor,
                lang_tokens: torch.LongTensor,
                lang_masks: torch.Tensor,
                img_masks: torch.Tensor,
                states: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """前向传播(训练时调用)。

        Args:
            images: 图像张量, shape=(B, N_cam, C, H, W)
            lang_tokens: 语言 token, shape=(B, seq_len)
            lang_masks: 语言掩码, shape=(B, seq_len)
            img_masks: 图像掩码, shape=(B, N_cam)
            states: 机器人状态, shape=(B, state_dim)
            actions: 目标动作序列, shape=(B, traj_len, action_dim)
            action_masks: 动作掩码, shape=(B, traj_len)

        Returns:
            dict: 包含 'loss' 键的字典
        """
        # 1. Extract vision-language features
        features, attention_mask, _ = self.vlm_backbone(
            images, lang_tokens, img_masks, lang_masks)

        # 2. Predict actions and compute the loss via the action head
        output = self.vla_head(
            input_features=features,
            states=states,
            attention_mask=attention_mask,
            actions=actions,
            action_masks=action_masks)

        return output  # {'pred_actions': ..., 'loss': ...}

    def predict_action(self,
                       images: torch.Tensor,
                       lang_tokens: torch.LongTensor,
                       lang_masks: torch.Tensor,
                       img_masks: torch.Tensor,
                       states: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """动作预测(推理时调用)。

        Returns:
            torch.Tensor: 预测的动作序列
        """
        features, attention_mask, _ = self.vlm_backbone(
            images, lang_tokens, img_masks, lang_masks)

        actions = self.vla_head.predict_action(
            input_features=features,
            states=states,
            attention_mask=attention_mask)

        return actions

    def get_fsdp_wrapping_policy(self):
        """Return the FSDP wrapping policy, i.e. which submodules get their own shard."""
        from functools import partial
        from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

        def should_wrap(module: nn.Module) -> bool:
            # lambda_auto_wrap_policy handles recursion itself; the lambda only
            # decides whether this particular module becomes its own FSDP unit.
            return isinstance(module, (
                nn.TransformerEncoderLayer,
                nn.TransformerDecoderLayer,
            ))

        return partial(lambda_auto_wrap_policy, lambda_fn=should_wrap)

Step 2: Understand the BaseVLA Interface#

BaseVLA is the abstract base class for all VLA models. It provides common initialization logic and pretrained weight loading functionality. Subclasses must implement the following abstract methods:

| Method | Purpose | Invocation Context |
|---|---|---|
| forward(...) | Forward pass; computes the training loss | Each step of the training loop |
| get_fsdp_wrapping_policy() | Returns the FSDP sharding policy | When using FSDPTrainRunner |

Functionality already provided by the base class (subclasses typically do not need to override these):

  • Automatic submodule construction: Automatically creates backbone, projector, and head based on the configuration

  • Pretrained weight loading: Loads pretrained models via pretrained_name_or_path and name_mapping

  • Module freezing: Controls parameter freezing via freeze_vlm_backbone, freeze_projector, etc.

  • from_pretrained(): Loads weights from a checkpoint
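Put together, the abstract surface a subclass must satisfy is small. The skeleton below is a hypothetical illustration of that contract only; the real BaseVLA additionally contains the construction, freezing, and weight-loading logic listed above:

from abc import ABC, abstractmethod

import torch.nn as nn


class BaseVLALike(nn.Module, ABC):
    """Hypothetical skeleton mirroring the abstract methods listed above."""

    @abstractmethod
    def forward(self, *args, **kwargs) -> dict:
        """Compute the training loss; must return a dict containing a 'loss' key."""

    @abstractmethod
    def get_fsdp_wrapping_policy(self):
        """Return the FSDP auto-wrap policy used by FSDPTrainRunner."""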

Step 3: Register and Import#

Add the import statement in fluxvla/models/vlas/__init__.py:

from .llava_vla import LlavaVLA  # noqa: F401, F403
from .my_vla import MyVLA  # noqa: F401, F403               ← new
from .open_vla import OpenVLA  # noqa: F401, F403
from .pi0_flowmatching import PI0FlowMatching  # noqa: F401, F403
from .pi05_flowmatching import PI05FlowMatching  # noqa: F401, F403

Step 4: Use in a Configuration File#

model = dict(
    type='MyVLA',                    # ← the VLA class name you registered
    my_custom_param=256,             # ← custom parameter
    pretrained_name_or_path='./checkpoints/my_pretrained',
    vlm_backbone=dict(
        type='EagleBackbone',
        vlm_path='fluxvla/models/third_party_models/eagle2_hg_model'),
    vla_head=dict(
        type='FlowMatchingHead',
        state_dim=64,
        action_dim=32,
        ori_action_dim=14,
        ...),
    freeze_vlm_backbone=False,
    freeze_projector=False)

Adding a New VLM Backbone#

When you need to integrate a new vision-language pretrained model (e.g., a new version of InternVL, LLaVA-Next, etc.) as a backbone, you should add a new VLM Backbone.

Step 1: Create the Backbone File#

Create a new file under fluxvla/models/backbones/vlms/, for example my_vlm.py:

from typing import Callable, Dict, Optional, Sequence, Type

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from fluxvla.engines.utils.root import VLM_BACKBONES


@VLM_BACKBONES.register_module()
class MyVLMBackbone(nn.Module):
    """自定义 VLM Backbone。"""

    def __init__(self,
                 vlm_path: str,
                 vlm_config: Optional[Dict] = None,
                 select_layer: int = -1,
                 dtype: str = 'float32'):
        super().__init__()
        # Load the pretrained VLM model,
        # e.g. via transformers AutoModel:
        from transformers import AutoModel, AutoConfig
        if vlm_config is not None:
            config = AutoConfig.from_pretrained(vlm_path, **vlm_config)
        else:
            config = AutoConfig.from_pretrained(vlm_path)
        self.vlm = AutoModel.from_pretrained(vlm_path, config=config)
        self.select_layer = select_layer
        self.config = config

    def forward(self,
                images: torch.Tensor,
                lang_tokens: torch.LongTensor,
                img_masks: torch.Tensor,
                lang_masks: torch.Tensor
                ) -> tuple:
        """前向传播。

        Args:
            images: 图像张量, shape=(B, N_cam, C, H, W)
            lang_tokens: 语言 token, shape=(B, seq_len)
            img_masks: 图像掩码, shape=(B, N_cam)
            lang_masks: 语言掩码, shape=(B, seq_len)

        Returns:
            tuple: (features, attention_mask, None)
                - features: VLM 输出特征, shape=(B, total_seq_len, hidden_dim)
                - attention_mask: 注意力掩码
        """
        # Implement your multimodal feature extraction logic
        outputs = self.vlm(
            input_ids=lang_tokens,
            pixel_values=images,
            attention_mask=lang_masks,
        )
        features = outputs.last_hidden_state
        return features, lang_masks, None

    def enable_gradient_checkpointing(self):
        """启用梯度检查点以节省显存。"""
        self.vlm.gradient_checkpointing_enable()

    def get_fsdp_wrapping_policy(self) -> Callable:
        """返回 FSDP 包装策略。"""
        from functools import partial
        from torch.distributed.fsdp.wrap import (
            lambda_auto_wrap_policy, transformer_auto_wrap_policy)

        # 获取 VLM 中的 Transformer 层类型
        layer_cls = self._get_transformer_layer_cls()
        return partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={layer_cls})

    def _get_transformer_layer_cls(self) -> Type[nn.Module]:
        """返回 Transformer 层的类型,用于 FSDP 包装。"""
        # 根据你的 VLM 架构返回对应的层类型
        raise NotImplementedError

Step 2: Understand the Input/Output Conventions#

The forward method of the VLM Backbone must adhere to the following conventions to ensure compatibility with the upstream VLA model:

Input Parameters:

| Parameter | Type | Description |
|---|---|---|
| images | torch.Tensor | Image tensor; shape varies by implementation |
| lang_tokens | torch.LongTensor | Language token IDs |
| img_masks | torch.Tensor | Image validity mask |
| lang_masks | torch.Tensor | Language validity mask |

Return Values:

return (features, attention_mask, extra)

| Return Value | Description |
|---|---|
| features | Multimodal fused features, shape=(B, seq_len, hidden_dim) |
| attention_mask | Attention mask, shape=(B, seq_len) |
| extra | Additional information (typically None) |
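Before wiring a new backbone into a full training run, a small smoke test with dummy tensors can confirm it honors this contract. The snippet below is illustrative: the vlm_path value is a placeholder and the shapes are examples only.

import torch

# Placeholder path: point this at the HuggingFace-style model directory you use.
backbone = MyVLMBackbone(vlm_path='path/to/your/vlm')

B, N_cam, seq_len = 2, 3, 32
features, attention_mask, extra = backbone(
    images=torch.randn(B, N_cam, 3, 224, 224),
    lang_tokens=torch.randint(0, 1000, (B, seq_len)),
    img_masks=torch.ones(B, N_cam, dtype=torch.bool),
    lang_masks=torch.ones(B, seq_len, dtype=torch.bool))

assert features.dim() == 3 and features.shape[0] == B  # (B, total_seq_len, hidden_dim)
assert attention_mask.shape[0] == B                    # mask batched the same way
assert extra is None                                   # third element is typically None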

Step 3: Register and Import#

Add the import statement in fluxvla/models/backbones/vlms/__init__.py:

from .eagle import EagleBackbone, EagleInferenceBackbone  # noqa: F401, F403
from .my_vlm import MyVLMBackbone  # noqa: F401, F403        ← new
from .paligemma import PaliGemma  # noqa: F401, F403
from .qwen2_5_vl import QWen2_5VL  # noqa: F401, F403

Adding a New VLA Head#

When you need to employ a new action prediction method (e.g., Diffusion Policy, VAE, VQ-VAE, etc.), you should add a new VLA Head.

Step 1: Create the Head File#

Create a new file under fluxvla/models/heads/, for example my_action_head.py:

from typing import Callable, Dict, Optional

import torch
import torch.nn as nn

from fluxvla.engines.utils.root import HEADS


@HEADS.register_module()
class MyActionHead(nn.Module):
    """自定义动作预测头。"""

    def __init__(self,
                 hidden_size: int = 1024,
                 state_dim: int = 64,
                 action_dim: int = 32,
                 input_embedding_dim: int = 1536,
                 traj_length: int = 10,
                 ori_action_dim: Optional[int] = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.traj_length = traj_length
        self.ori_action_dim = ori_action_dim or action_dim

        # State encoder: map the robot state into the hidden space
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size))

        # Action decoder: predict the action sequence from features
        self.action_decoder = nn.Sequential(
            nn.Linear(input_embedding_dim + hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, traj_length * action_dim))

    def forward(self,
                input_features: torch.Tensor,
                states: torch.Tensor,
                attention_mask: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """前向传播(训练时调用)。

        Args:
            input_features: VLM 输出特征, shape=(B, seq_len, embed_dim)
            states: 机器人状态, shape=(B, state_dim)
            attention_mask: 注意力掩码, shape=(B, seq_len)
            actions: 目标动作, shape=(B, traj_length, action_dim)
            action_masks: 动作掩码, shape=(B, traj_length)

        Returns:
            dict: {'pred_actions': ..., 'loss': ...}
        """
        B = input_features.shape[0]

        # 1. Encode the state
        state_emb = self.state_encoder(states)  # (B, hidden_size)

        # 2. Pool the VLM features (take the last valid token)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        pooled = input_features[range(B), seq_lengths]  # (B, embed_dim)

        # 3. Concatenate and predict the actions
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)  # (B, traj_len * action_dim)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)

        # 4. Compute the loss
        loss = nn.functional.mse_loss(
            pred_actions[action_masks.bool()],
            actions[action_masks.bool()])

        return {'pred_actions': pred_actions, 'loss': loss}

    def predict_action(self,
                       input_features: torch.Tensor,
                       states: torch.Tensor,
                       attention_mask: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """动作预测(推理时调用)。

        Returns:
            torch.Tensor: 预测动作序列, shape=(B, traj_length, action_dim)
        """
        B = input_features.shape[0]
        state_emb = self.state_encoder(states)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        pooled = input_features[range(B), seq_lengths]
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
        return pred_actions

    def get_fsdp_wrapping_policy(self) -> Callable:
        """返回 FSDP 包装策略。"""
        return None  # 简单模型无需特殊的 FSDP 分片策略
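Because this head has no pretrained dependencies, its contract can be exercised directly with dummy tensors. A minimal check might look like the following (shapes chosen to match the constructor arguments):

import torch

head = MyActionHead(hidden_size=256, state_dim=64, action_dim=32,
                    input_embedding_dim=1536, traj_length=10)

B, seq_len = 2, 48
out = head(
    input_features=torch.randn(B, seq_len, 1536),
    states=torch.randn(B, 64),
    attention_mask=torch.ones(B, seq_len),
    actions=torch.randn(B, 10, 32),
    action_masks=torch.ones(B, 10))

print(out['loss'].item())         # scalar training loss
print(out['pred_actions'].shape)  # torch.Size([2, 10, 32])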

Step 2: Understand the Head Interface Conventions#

A VLA Head must implement two core methods:

| Method | Scenario | Input | Output |
|---|---|---|---|
| forward(...) | Training | Features + states + target actions | {'pred_actions': Tensor, 'loss': Tensor} |
| predict_action(...) | Inference | Features + states | Tensor (predicted action sequence) |

Design references for existing Heads:

| Head Type | Action Prediction Approach | Characteristics |
|---|---|---|
| FlowMatchingHead | Flow Matching (DiT) | Generative, multi-step denoising, suitable for complex action distributions |
| LlavaActionHead | Transformer decoding | Autoregressive prediction, supports variable-length trajectories |
| OpenVLAHead | Token discretization | Token prediction based on the VLM, no additional network required |

Step 3: Register and Import#

Add the import statement in fluxvla/models/heads/__init__.py:

from .flow_matching_head import FlowMatchingHead  # noqa: F401, F403
from .flow_matching_inference_head import FlowMatchingInferenceHead  # noqa: F401, F403
from .llava_action_head import LlavaActionHead  # noqa: F401, F403
from .my_action_head import MyActionHead  # noqa: F401, F403   ← new
from .openvla_head import OpenVLAHead  # noqa: F401, F403

Adding a New Projector#

When the output dimension of your vision backbone does not match the input dimension of the language model, a Projector is required to perform feature space mapping.

import torch
import torch.nn as nn

from fluxvla.engines.utils.root import PROJECTORS


@PROJECTORS.register_module()
class MyProjector(nn.Module):
    """Example custom feature projector."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.projector = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim))

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        return self.projector(input_features)
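As a quick sanity check with dummy data (the dimensions below are illustrative), the projector simply remaps the last feature dimension:

import torch

proj = MyProjector(in_dim=2176, out_dim=4096)
print(proj(torch.randn(2, 256, 2176)).shape)  # torch.Size([2, 256, 4096])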

After adding the import in fluxvla/models/projectors/__init__.py, the projector can be used via configuration:

model = dict(
    type='OpenVLA',
    projector=dict(
        type='MyProjector',
        in_dim=2176,
        out_dim=4096),
    ...)

Composing Models in Configuration Files#

The model configuration in FluxVLA adopts a nested dictionary approach: the top level specifies the VLA type, and submodule configurations are nested within. BaseVLA automatically constructs all submodules based on the configuration.

Configuration Structure#

model = dict(
    type='...',                     # VLA class name → VLAS registry
    pretrained_name_or_path='...',  # path to pretrained weights
    vlm_backbone=dict(              # VLM backbone → VLM_BACKBONES registry
        type='...',
        ...),
    # Alternatively, use a vision_backbone + llm_backbone + projector combination:
    # vision_backbone=dict(type='...', ...),   → VISION_BACKBONES
    # llm_backbone=dict(type='...', ...),      → LLM_BACKBONES
    # projector=dict(type='...', ...),         → PROJECTORS
    vla_head=dict(                  # VLA head → HEADS registry
        type='...',
        ...),
    freeze_vlm_backbone=False,
    freeze_projector=False,
    name_mapping={...},             # key-name mapping for pretrained weights
)

Two Backbone Composition Modes#

FluxVLA supports two backbone composition modes, which are mutually exclusive:

Mode 1: Using a VLM Backbone (Recommended)

A VLM backbone is an end-to-end vision-language model (e.g., Eagle, PaliGemma) that has a built-in vision encoder and language model:

model = dict(
    type='LlavaVLA',
    vlm_backbone=dict(type='EagleBackbone', vlm_path='...'),
    vla_head=dict(type='FlowMatchingHead', ...),
    ...)

Mode 2: Using Vision + LLM + Projector

Specify the vision encoder, language model, and projection layer separately:

model = dict(
    type='OpenVLA',
    vision_backbone=dict(type='DinoSigLIPViTBackbone', ...),
    llm_backbone=dict(type='LLaMa2LLMBackbone', ...),
    projector=dict(type='FusedMLPProjector', ...),
    vla_head=dict(type='OpenVLAHead', ...),
    ...)

name_mapping for Pretrained Weight Mapping#

When the parameter names in the pretrained model differ from the module names in FluxVLA, use name_mapping to establish the mapping relationship:

name_mapping={
    'vlm_backbone.vlm': 'backbone.eagle_model',
    'vla_head': 'action_head'
}

The above mapping indicates that parameters under backbone.eagle_model.* in the pretrained checkpoint will be loaded into vlm_backbone.vlm.*, and parameters under action_head.* will be loaded into vla_head.*.
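Conceptually, the loading step applies a prefix rewrite to the checkpoint's state_dict. The helper below is only an illustrative sketch of that idea (FluxVLA's actual loading code may handle this differently):

def remap_state_dict(state_dict: dict, name_mapping: dict) -> dict:
    """Illustrative prefix rewrite: checkpoint parameter names → FluxVLA module names."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for fluxvla_prefix, ckpt_prefix in name_mapping.items():
            if key.startswith(ckpt_prefix + '.'):
                new_key = fluxvla_prefix + key[len(ckpt_prefix):]
                break
        remapped[new_key] = value
    return remapped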


Complete Examples: Using Existing Configurations as References#

The following examples demonstrate three different model configuration styles, using configs/gr00t/gr00t_eagle_3b_aloha_full_finetune.py and the existing configurations as references.

Example 1: GR00T Architecture (LlavaVLA + Eagle + FlowMatchingHead)#

This is the actual configuration from configs/gr00t/gr00t_eagle_3b_aloha_full_finetune.py:

model = dict(
    type='LlavaVLA',
    pretrained_name_or_path=
    './checkpoints/GR00T-N1.5-3B',
    vlm_backbone=dict(
        type='EagleBackbone',
        vlm_path=
        'fluxvla/models/third_party_models/eagle2_hg_model'),
    vla_head=dict(
        type='FlowMatchingHead',
        state_dim=64,
        hidden_size=1024,
        input_embedding_dim=1536,
        num_layers=1,
        num_heads=4,
        num_inference_timesteps=4,
        num_steps=32,
        traj_length=10,
        action_dim=32,
        ori_action_dim=14),
    freeze_vlm_backbone=False,
    name_mapping={
        'vlm_backbone.vlm': 'backbone.eagle_model',
        'vla_head': 'action_head'
    },
    freeze_projector=False)

Architecture Analysis:

| Component | Selection | Description |
|---|---|---|
| VLA | LlavaVLA | Continuous action prediction VLA with a VLM backbone |
| Backbone | EagleBackbone | Eagle2.5-VL multimodal model |
| Head | FlowMatchingHead | Flow Matching generative action head (DiT) |
| Training Strategy | Full-parameter fine-tuning | freeze_vlm_backbone=False |

Example 2: PI0 Architecture (PI0FlowMatching + PaliGemma)#

model = dict(
    type='PI0FlowMatching',
    vlm_backbone=dict(
        type='PaliGemma',
        vlm_backbone_id='paligemma_3b_pt_224',
        vlm_config=dict(
            text_config=dict(
                num_hidden_layers=18,
                hidden_size=2048,
                intermediate_size=16384,
                num_attention_heads=8,
                num_key_value_heads=1))),
    proj_width=1024,
    n_action_steps=32,
    state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
    llm_expert=dict(
        type='GemmaLLMBackbone',
        llm_backbone_id='gemma-2b_causal',
        llm_max_length=256),
    freeze_vlm_backbone=True,
    pretrained_name_or_path='./checkpoints/pi0_pretrained',
    name_mapping={
        'vlm_backbone.vlm': 'vla.vlm',
        'llm_expert.llm': 'vla.language_model'
    })

Architecture Analysis:

| Component | Selection | Description |
|---|---|---|
| VLA | PI0FlowMatching | PI0 Flow Matching architecture |
| Backbone | PaliGemma | PaliGemma 3B VLM |
| Expert | GemmaLLMBackbone | Additional Gemma LLM for action decoding |
| Projection | LinearProjector | Linear projection for states/actions |
| Training Strategy | Frozen VLM | freeze_vlm_backbone=True |

Example 3: OpenVLA Architecture (Separate Vision + LLM)#

model = dict(
    type='OpenVLA',
    arch_specifier='no-align+fused-gelu-mlp',
    vision_backbone=dict(
        type='DinoSigLIPViTBackbone',
        vision_backbone_id='dinosiglip-vit-so-224px',
        dino_config=dict(image_size=224, patch_size=14),
        siglip_config=dict(image_size=224, patch_size=14)),
    llm_backbone=dict(
        type='LLaMa2LLMBackbone',
        llm_backbone_id='llama2-7b-pure_causal',
        llm_max_length=2048),
    projector=dict(
        type='FusedMLPProjector',
        fused_vision_dim=2176,
        llm_dim=4096),
    vla_head=dict(
        type='OpenVLAHead',
        norm_stats=None,
        vocab_size=32000),
    freeze_vision_backbone=False,
    freeze_llm_backbone=False,
    freeze_projector=False)

Architecture Analysis:

| Component | Selection | Description |
|---|---|---|
| VLA | OpenVLA | Token-based action prediction VLA |
| Vision | DinoSigLIPViTBackbone | DINO + SigLIP fused vision encoder |
| Language | LLaMa2LLMBackbone | LLaMA-2 7B language model |
| Projection | FusedMLPProjector | Fused MLP projection |
| Head | OpenVLAHead | Token-discretized action head |


Core Checklist#

After adding a new model component, verify each item in the following checklist:

  • [ ] The model file has been created in the correct directory (vlas/, backbones/vlms/, heads/, etc.)

  • [ ] The appropriate registration decorator has been applied (@VLAS.register_module(), @VLM_BACKBONES.register_module(), @HEADS.register_module(), etc.)

  • [ ] An import statement has been added in the corresponding __init__.py

  • [ ] The VLA model implements forward() and get_fsdp_wrapping_policy()

  • [ ] The VLM Backbone’s forward() returns a (features, attention_mask, extra) tuple

  • [ ] The VLA Head implements both forward() and predict_action() methods

  • [ ] forward() returns a dictionary containing the 'loss' key

  • [ ] predict_action() returns an action tensor

  • [ ] The type in the configuration file exactly matches the registered class name

  • [ ] The parameter names in the configuration file match the parameter names in the __init__ method

  • [ ] If using pretrained weights, name_mapping correctly maps the parameter names
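A quick way to exercise most of the checklist at once is to build the model directly from a config dict, mirroring what BaseTrainRunner does. The snippet below is a hypothetical sketch: the import path of build_vla_from_cfg and the chosen parameters are assumptions you should adapt to your setup.

# Hypothetical import path; use wherever FluxVLA actually exposes the builder.
from fluxvla.engines.utils.root import build_vla_from_cfg

cfg_model = dict(
    type='MyVLA',
    my_custom_param=256,
    vla_head=dict(
        type='MyActionHead',
        hidden_size=256,
        state_dim=64,
        action_dim=32,
        input_embedding_dim=1536,
        traj_length=10))

vla = build_vla_from_cfg(cfg_model)
print(type(vla).__name__)  # MyVLA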


Frequently Asked Questions#

Q: How do I decide whether to inherit from BaseVLA or an existing VLA?#

  • If your model is fundamentally different from existing VLA architectures (e.g., a new training paradigm), inherit from BaseVLA

  • If your model is a variant of an existing architecture (e.g., swapping the backbone or head), simply switch submodules via the configuration file — there is no need to create a new VLA class

  • If you need to modify the forward logic while reusing most of the code, inherit from the most similar VLA (e.g., LlavaVLA)

Q: What configuration changes are needed after adding a new Head?#

You only need to modify the type and corresponding parameters in the model.vla_head dictionary:

model = dict(
    type='LlavaVLA',                    # VLA unchanged
    vlm_backbone=dict(...),             # backbone unchanged
    vla_head=dict(
        type='MyActionHead',            # ← swap in your Head
        hidden_size=1024,
        state_dim=64,
        action_dim=32,
        input_embedding_dim=1536,
        traj_length=10,
        ori_action_dim=14),
    ...)

Q: What is inference_model?#

inference_model is an optional inference-specific model configuration. It is used when a different model setup is required during inference (e.g., replacing FlowMatchingHead with FlowMatchingInferenceHead to enable CUDA Graph acceleration):

model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingHead', ...), ...)
inference_model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingInferenceHead', ...), ...)

If inference_model is not defined, the model configuration is used by default during inference.

Q: How do I ensure my Head is compatible with FSDP training?#

If your Head contains a large number of parameters, you should implement the get_fsdp_wrapping_policy() method to specify which submodules should be individually sharded by FSDP:

def get_fsdp_wrapping_policy(self):
    from functools import partial
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    return partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={MyTransformerBlock})

If the Head has a small number of parameters, returning None is sufficient.

Q: What is the mapping direction of name_mapping?#

The format of name_mapping is {module name in FluxVLA: module name in pretrained checkpoint}:

name_mapping={
    'vlm_backbone.vlm': 'backbone.eagle_model',
    # Meaning: FluxVLA's vlm_backbone.vlm.xxx ← checkpoint's backbone.eagle_model.xxx
}

During weight loading, the framework renames parameters with the backbone.eagle_model. prefix in the checkpoint to vlm_backbone.vlm. before loading them.