Adding Custom Models#
Overview#
In FluxVLA, a complete VLA (Vision-Language-Action) model is composed of multiple submodules:
| Submodule | Registry | Purpose | Examples |
|---|---|---|---|
| VLA | `VLAS` | Top-level model that orchestrates all submodules | `OpenVLA`, `LlavaVLA`, `PI0FlowMatching` |
| VLM Backbone | `VLM_BACKBONES` | Vision-language multimodal understanding | `EagleBackbone`, `PaliGemma`, `QWen2_5VL` |
| Vision Backbone | `VISION_BACKBONES` | Pure visual feature extraction | `SigLIPViTBackbone`, `DinoSigLIPViTBackbone` |
| LLM Backbone | `LLM_BACKBONES` | Pure language model | `LLaMa2LLMBackbone`, `GemmaLLMBackbone`, `Qwen2LLMBackbone` |
| Projector | `PROJECTORS` | Feature space mapping (vision → language) | `MLPProjector`, `LinearProjector`, `FusedMLPProjector` |
| VLA Head | `HEADS` | Action prediction head | `FlowMatchingHead`, `LlavaActionHead`, `OpenVLAHead` |
This tutorial provides guidance on how to add custom model components to FluxVLA — whether it is an entirely new VLA architecture, a new VLM backbone, or a new action prediction head.
Architecture Overview#
Model Inheritance Hierarchy#
BaseVLA (ABC)
├── OpenVLA              # Token-based action prediction
│   └── LlavaVLA         # Continuous action prediction with VLM backbone support
└── PI0FlowMatching      # Flow Matching generative action prediction
    └── PI05FlowMatching # Improved PI0.5 variant
Directory Structure#
All model-related code is located under fluxvla/models/:
fluxvla/models/
├── __init__.py                          # Exports backbones, heads, projectors, vlas
├── vlas/
│   ├── __init__.py                      # Imports all VLAs (triggers registration)
│   ├── base_vla.py                      # Abstract base class for all VLAs
│   ├── open_vla.py                      # OpenVLA
│   ├── llava_vla.py                     # LlavaVLA (inherits from OpenVLA)
│   ├── pi0_flowmatching.py              # PI0FlowMatching
│   └── pi05_flowmatching.py             # PI05FlowMatching (inherits from PI0)
├── backbones/
│   ├── vlms/                            # VLM backbones
│   │   ├── eagle.py                     # EagleBackbone
│   │   ├── paligemma.py                 # PaliGemma
│   │   └── qwen2_5_vl.py                # QWen2_5VL
│   ├── visions/                         # Vision backbones
│   │   ├── siglip_vit.py                # SigLIPViTBackbone
│   │   └── dinosiglip_vit.py            # DinoSigLIPViTBackbone
│   └── llms/                            # LLM backbones
│       ├── llama2.py                    # LLaMa2LLMBackbone
│       ├── gemma.py                     # GemmaLLMBackbone
│       └── qwen2.py                     # Qwen2LLMBackbone
├── heads/
│   ├── flow_matching_head.py            # FlowMatchingHead (DiT)
│   ├── flow_matching_inference_head.py  # FlowMatchingInferenceHead (CUDA Graph)
│   ├── llava_action_head.py             # LlavaActionHead (Transformer)
│   └── openvla_head.py                  # OpenVLAHead (Token-based)
├── projectors/
│   ├── mlp_projector.py                 # MLPProjector
│   ├── linear_projector.py              # LinearProjector
│   └── fused_projector.py               # FusedMLPProjector
└── blocks/
    └── cross_attention_dit.py           # DiT Transformer block
Registration and Construction Mechanism#
FluxVLA employs a unified Registry pattern to manage all model components. Each component is registered via a decorator, and configuration files reference components through the type field.
Registration:
from fluxvla.engines.utils.root import VLAS
@VLAS.register_module()
class MyVLA(BaseVLA):
...
Construction (automatically performed within BaseVLA.__init__):
# Submodule construction logic inside BaseVLA.__init__
if vlm_backbone is not None:
    self.vlm_backbone = build_vlm_backbone_from_cfg(vlm_backbone)
if vision_backbone is not None:
    self.vision_backbone = build_vision_backbone_from_cfg(vision_backbone)
if llm_backbone is not None:
    self.llm_backbone = build_llm_backbone_from_cfg(llm_backbone)
if projector is not None:
    self.projector = build_projector_from_cfg(projector)
if vla_head is not None:
    self.vla_head = build_head_from_cfg(vla_head)
Model construction during training (BaseTrainRunner.__init__):
self.vla = build_vla_from_cfg(cfg.model)  # reads the model dict from the config
Therefore, you only need to specify the type and the corresponding parameters in the configuration file; the framework will automatically look up and instantiate all submodules through the registry.
Adding a New VLA Model#
When you need to implement an entirely new VLA architecture (e.g., a different training paradigm or a different modality fusion strategy), you should add a new top-level VLA model.
Step 1: Create the Model File#
Create a new file under fluxvla/models/vlas/, for example my_vla.py:
from typing import Dict, Optional

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from fluxvla.engines.utils.root import VLAS
from .base_vla import BaseVLA


@VLAS.register_module()
class MyVLA(BaseVLA):
    """Example custom VLA model."""

    def __init__(self,
                 my_custom_param: int = 256,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.my_custom_param = my_custom_param
        # BaseVLA has already built the following from the config:
        #   self.vlm_backbone    (if vlm_backbone is specified in the config)
        #   self.vision_backbone (if vision_backbone is specified in the config)
        #   self.llm_backbone    (if llm_backbone is specified in the config)
        #   self.projector       (if projector is specified in the config)
        #   self.vla_head        (if vla_head is specified in the config)

    def forward(self,
                images: torch.Tensor,
                lang_tokens: torch.LongTensor,
                lang_masks: torch.Tensor,
                img_masks: torch.Tensor,
                states: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """Forward pass (called during training).

        Args:
            images: Image tensor, shape=(B, N_cam, C, H, W)
            lang_tokens: Language tokens, shape=(B, seq_len)
            lang_masks: Language mask, shape=(B, seq_len)
            img_masks: Image mask, shape=(B, N_cam)
            states: Robot states, shape=(B, state_dim)
            actions: Target action sequence, shape=(B, traj_len, action_dim)
            action_masks: Action mask, shape=(B, traj_len)

        Returns:
            dict: Dictionary containing the 'loss' key
        """
        # 1. Extract vision-language features
        features, attention_mask, _ = self.vlm_backbone(
            images, lang_tokens, img_masks, lang_masks)
        # 2. Predict actions and compute the loss via the action head
        output = self.vla_head(
            input_features=features,
            states=states,
            attention_mask=attention_mask,
            actions=actions,
            action_masks=action_masks)
        return output  # {'pred_actions': ..., 'loss': ...}

    def predict_action(self,
                       images: torch.Tensor,
                       lang_tokens: torch.LongTensor,
                       lang_masks: torch.Tensor,
                       img_masks: torch.Tensor,
                       states: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """Action prediction (called during inference).

        Returns:
            torch.Tensor: Predicted action sequence
        """
        features, attention_mask, _ = self.vlm_backbone(
            images, lang_tokens, img_masks, lang_masks)
        actions = self.vla_head.predict_action(
            input_features=features,
            states=states,
            attention_mask=attention_mask)
        return actions

    def get_fsdp_wrapping_policy(self):
        """Return the FSDP wrapping policy (which submodules to shard individually)."""
        from functools import partial
        from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

        def should_wrap(module, recurse, *args, **kwargs):
            if recurse:
                return True
            # Module types that should be sharded individually
            return isinstance(module, (
                nn.TransformerEncoderLayer,
                nn.TransformerDecoderLayer,
            ))

        return partial(lambda_auto_wrap_policy, lambda_fn=should_wrap)
Step 2: Understand the BaseVLA Interface#
BaseVLA is the abstract base class for all VLA models. It provides common initialization logic and pretrained weight loading functionality. Subclasses must implement the following abstract methods:
| Method | Purpose | Invocation Context |
|---|---|---|
| `forward()` | Forward pass; computes the training loss | Each step of the training loop |
| `get_fsdp_wrapping_policy()` | Returns the FSDP sharding policy | When using FSDP distributed training |
Functionality already provided by the base class (subclasses typically do not need to override these):
- Automatic submodule construction: Automatically creates backbone, projector, and head based on the configuration
- Pretrained weight loading: Loads pretrained models via `pretrained_name_or_path` and `name_mapping`
- Module freezing: Controls parameter freezing via `freeze_vlm_backbone`, `freeze_projector`, etc.
- `from_pretrained()`: Loads weights from a checkpoint
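The table above notes that forward() is invoked on every training step and must return the training loss. The following hypothetical sketch (the real BaseTrainRunner additionally handles FSDP, mixed precision, gradient accumulation, and logging) illustrates how that contract is consumed by a training loop:

def train_step(vla, batch, optimizer):
    # batch provides images, lang_tokens, lang_masks, img_masks,
    # states, actions, and action_masks (see the forward() docstring in Step 1)
    output = vla(**batch)   # forward() must return a dict containing 'loss'
    loss = output['loss']
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()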
Step 3: Register and Import#
Add the import statement in fluxvla/models/vlas/__init__.py:
from .llava_vla import LlavaVLA # noqa: F401, F403
from .my_vla import MyVLA  # noqa: F401, F403  ← new
from .open_vla import OpenVLA # noqa: F401, F403
from .pi0_flowmatching import PI0FlowMatching # noqa: F401, F403
from .pi05_flowmatching import PI05FlowMatching # noqa: F401, F403
Step 4: Use in a Configuration File#
model = dict(
    type='MyVLA',  # ← the name of your registered VLA class
    my_custom_param=256,  # ← custom parameter
pretrained_name_or_path='./checkpoints/my_pretrained',
vlm_backbone=dict(
type='EagleBackbone',
vlm_path='fluxvla/models/third_party_models/eagle2_hg_model'),
vla_head=dict(
type='FlowMatchingHead',
state_dim=64,
action_dim=32,
ori_action_dim=14,
...),
freeze_vlm_backbone=False,
freeze_projector=False)
Adding a New VLM Backbone#
When you need to integrate a new vision-language pretrained model (e.g., a new version of InternVL, LLaVA-Next, etc.) as a backbone, you should add a new VLM Backbone.
Step 1: Create the Backbone File#
Create a new file under fluxvla/models/backbones/vlms/, for example my_vlm.py:
from typing import Callable, Dict, Optional, Sequence, Type

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from fluxvla.engines.utils.root import VLM_BACKBONES


@VLM_BACKBONES.register_module()
class MyVLMBackbone(nn.Module):
    """Custom VLM backbone."""

    def __init__(self,
                 vlm_path: str,
                 vlm_config: Optional[Dict] = None,
                 select_layer: int = -1,
                 dtype: str = 'float32'):
        super().__init__()
        # Load the pretrained VLM,
        # e.g. via transformers' AutoModel:
        from transformers import AutoModel, AutoConfig
        if vlm_config is not None:
            config = AutoConfig.from_pretrained(vlm_path, **vlm_config)
        else:
            config = AutoConfig.from_pretrained(vlm_path)
        self.vlm = AutoModel.from_pretrained(vlm_path, config=config)
        self.select_layer = select_layer
        self.config = config

    def forward(self,
                images: torch.Tensor,
                lang_tokens: torch.LongTensor,
                img_masks: torch.Tensor,
                lang_masks: torch.Tensor
                ) -> tuple:
        """Forward pass.

        Args:
            images: Image tensor, shape=(B, N_cam, C, H, W)
            lang_tokens: Language tokens, shape=(B, seq_len)
            img_masks: Image mask, shape=(B, N_cam)
            lang_masks: Language mask, shape=(B, seq_len)

        Returns:
            tuple: (features, attention_mask, None)
                - features: VLM output features, shape=(B, total_seq_len, hidden_dim)
                - attention_mask: attention mask
        """
        # Implement your multimodal feature extraction logic here
        outputs = self.vlm(
            input_ids=lang_tokens,
            pixel_values=images,
            attention_mask=lang_masks,
        )
        features = outputs.last_hidden_state
        return features, lang_masks, None

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing to save GPU memory."""
        self.vlm.gradient_checkpointing_enable()

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return the FSDP wrapping policy."""
        from functools import partial
        from torch.distributed.fsdp.wrap import (
            lambda_auto_wrap_policy, transformer_auto_wrap_policy)

        # Look up the Transformer layer class used inside the VLM
        layer_cls = self._get_transformer_layer_cls()
        return partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={layer_cls})

    def _get_transformer_layer_cls(self) -> Type[nn.Module]:
        """Return the Transformer layer class used for FSDP wrapping."""
        # Return the layer class matching your VLM architecture
        raise NotImplementedError
Step 2: Understand the Input/Output Conventions#
The forward method of the VLM Backbone must adhere to the following conventions to ensure compatibility with the upstream VLA model:
Input Parameters:
| Parameter | Type | Description |
|---|---|---|
| `images` | `torch.Tensor` | Image tensor; shape varies by implementation |
| `lang_tokens` | `torch.LongTensor` | Language token IDs |
| `img_masks` | `torch.Tensor` | Image validity mask |
| `lang_masks` | `torch.Tensor` | Language validity mask |
Return Values:
return (features, attention_mask, extra)
| Return Value | Description |
|---|---|
| `features` | Multimodal fused features, shape=(B, seq_len, hidden_dim) |
| `attention_mask` | Attention mask, shape=(B, seq_len) |
| `extra` | Additional information (typically `None`) |
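To make the conventions above concrete, here is a small, optional sanity check you can adapt to any backbone implementation. The dummy input shapes (two cameras, 224×224 images, a sequence length of 32) and dtypes are placeholder assumptions and may need adjusting to your backbone's expected input format:

import torch

def check_vlm_backbone_contract(backbone, n_cam=2, seq_len=32, img_size=224):
    """Verify that a backbone returns the (features, attention_mask, extra) tuple."""
    B = 2
    images = torch.randn(B, n_cam, 3, img_size, img_size)
    lang_tokens = torch.randint(0, 1000, (B, seq_len))
    lang_masks = torch.ones(B, seq_len, dtype=torch.bool)
    img_masks = torch.ones(B, n_cam, dtype=torch.bool)

    features, attention_mask, extra = backbone(
        images, lang_tokens, img_masks, lang_masks)

    assert features.dim() == 3, 'features must have shape (B, seq_len, hidden_dim)'
    assert attention_mask.shape == features.shape[:2], \
        'attention_mask must align with the feature sequence'
    return features, attention_mask, extra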
Step 3: Register and Import#
Add the import statement in fluxvla/models/backbones/vlms/__init__.py:
from .eagle import EagleBackbone, EagleInferenceBackbone # noqa: F401, F403
from .my_vlm import MyVLMBackbone  # noqa: F401, F403  ← new
from .paligemma import PaliGemma # noqa: F401, F403
from .qwen2_5_vl import QWen2_5VL # noqa: F401, F403
Adding a New VLA Head#
When you need to employ a new action prediction method (e.g., Diffusion Policy, VAE, VQ-VAE, etc.), you should add a new VLA Head.
Step 1: Create the Head File#
Create a new file under fluxvla/models/heads/, for example my_action_head.py:
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn

from fluxvla.engines.utils.root import HEADS


@HEADS.register_module()
class MyActionHead(nn.Module):
    """Custom action prediction head."""

    def __init__(self,
                 hidden_size: int = 1024,
                 state_dim: int = 64,
                 action_dim: int = 32,
                 input_embedding_dim: int = 1536,
                 traj_length: int = 10,
                 ori_action_dim: Optional[int] = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.traj_length = traj_length
        self.ori_action_dim = ori_action_dim or action_dim
        # State encoder: maps the robot state into the hidden space
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size))
        # Action decoder: predicts the action sequence from features
        self.action_decoder = nn.Sequential(
            nn.Linear(input_embedding_dim + hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, traj_length * action_dim))

    def forward(self,
                input_features: torch.Tensor,
                states: torch.Tensor,
                attention_mask: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """Forward pass (called during training).

        Args:
            input_features: VLM output features, shape=(B, seq_len, embed_dim)
            states: Robot states, shape=(B, state_dim)
            attention_mask: Attention mask, shape=(B, seq_len)
            actions: Target actions, shape=(B, traj_length, action_dim)
            action_masks: Action mask, shape=(B, traj_length)

        Returns:
            dict: {'pred_actions': ..., 'loss': ...}
        """
        B = input_features.shape[0]
        # 1. Encode the state
        state_emb = self.state_encoder(states)  # (B, hidden_size)
        # 2. Pool the VLM features (take the last valid token)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        pooled = input_features[range(B), seq_lengths]  # (B, embed_dim)
        # 3. Concatenate and predict actions
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)  # (B, traj_len * action_dim)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
        # 4. Compute the loss
        loss = nn.functional.mse_loss(
            pred_actions[action_masks.bool()],
            actions[action_masks.bool()])
        return {'pred_actions': pred_actions, 'loss': loss}

    def predict_action(self,
                       input_features: torch.Tensor,
                       states: torch.Tensor,
                       attention_mask: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """Action prediction (called during inference).

        Returns:
            torch.Tensor: Predicted action sequence, shape=(B, traj_length, action_dim)
        """
        B = input_features.shape[0]
        state_emb = self.state_encoder(states)
        seq_lengths = attention_mask.sum(dim=1).long() - 1
        pooled = input_features[range(B), seq_lengths]
        combined = torch.cat([pooled, state_emb], dim=-1)
        pred_flat = self.action_decoder(combined)
        pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
        return pred_actions

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return the FSDP wrapping policy."""
        return None  # a simple model needs no special FSDP sharding policy
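Before wiring the head into a VLA, you can smoke-test it in isolation with random tensors that follow the shape conventions documented above (the batch size and sequence length below are arbitrary):

import torch

head = MyActionHead(hidden_size=1024, state_dim=64, action_dim=32,
                    input_embedding_dim=1536, traj_length=10)

B, seq_len = 2, 48
input_features = torch.randn(B, seq_len, 1536)
states = torch.randn(B, 64)
attention_mask = torch.ones(B, seq_len)
actions = torch.randn(B, 10, 32)
action_masks = torch.ones(B, 10)

out = head(input_features=input_features, states=states,
           attention_mask=attention_mask, actions=actions,
           action_masks=action_masks)
print(out['loss'])                # scalar training loss
print(out['pred_actions'].shape)  # torch.Size([2, 10, 32])

pred = head.predict_action(input_features=input_features, states=states,
                           attention_mask=attention_mask)
print(pred.shape)                 # torch.Size([2, 10, 32])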
Step 2: Understand the Head Interface Conventions#
A VLA Head must implement two core methods:
| Method | Scenario | Input | Output |
|---|---|---|---|
| `forward()` | Training | Features + states + target actions | Dict containing `'pred_actions'` and `'loss'` |
| `predict_action()` | Inference | Features + states | Predicted action tensor |
Design references for existing Heads:
| Head Type | Action Prediction Approach | Characteristics |
|---|---|---|
| `FlowMatchingHead` | Flow Matching (DiT) | Generative, multi-step denoising, suitable for complex action distributions |
| `LlavaActionHead` | Transformer decoding | Autoregressive prediction, supports variable-length trajectories |
| `OpenVLAHead` | Token discretization | Token prediction based on the VLM, no additional network required |
Step 3: Register and Import#
Add the import statement in fluxvla/models/heads/__init__.py:
from .flow_matching_head import FlowMatchingHead # noqa: F401, F403
from .flow_matching_inference_head import FlowMatchingInferenceHead # noqa: F401, F403
from .llava_action_head import LlavaActionHead # noqa: F401, F403
from .my_action_head import MyActionHead  # noqa: F401, F403  ← new
from .openvla_head import OpenVLAHead # noqa: F401, F403
Adding a New Projector#
When the output dimension of your vision backbone does not match the input dimension of the language model, a Projector is required to perform feature space mapping.
import torch
import torch.nn as nn

from fluxvla.engines.utils.root import PROJECTORS


@PROJECTORS.register_module()
class MyProjector(nn.Module):
    """Custom feature projector."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.projector = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim))

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        return self.projector(input_features)
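A quick usage sketch, with dimensions chosen to match the configuration example below (2176-dimensional vision features projected into a 4096-dimensional LLM space; the patch count is arbitrary):

import torch

proj = MyProjector(in_dim=2176, out_dim=4096)
vision_features = torch.randn(2, 256, 2176)  # (B, num_patches, in_dim)
llm_features = proj(vision_features)         # (B, num_patches, 4096)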
After adding the import in fluxvla/models/projectors/__init__.py, the projector can be used via configuration:
model = dict(
type='OpenVLA',
projector=dict(
type='MyProjector',
in_dim=2176,
out_dim=4096),
...)
Composing Models in Configuration Files#
The model configuration in FluxVLA adopts a nested dictionary approach: the top level specifies the VLA type, and submodule configurations are nested within. BaseVLA automatically constructs all submodules based on the configuration.
Configuration Structure#
model = dict(
    type='...',                      # VLA class name → VLAS registry
    pretrained_name_or_path='...',   # path to pretrained weights
    vlm_backbone=dict(               # VLM backbone → VLM_BACKBONES registry
        type='...',
        ...),
    # Or use the vision_backbone + llm_backbone + projector combination:
    # vision_backbone=dict(type='...', ...),  → VISION_BACKBONES
    # llm_backbone=dict(type='...', ...),     → LLM_BACKBONES
    # projector=dict(type='...', ...),        → PROJECTORS
    vla_head=dict(                   # VLA Head → HEADS registry
        type='...',
        ...),
    freeze_vlm_backbone=False,
    freeze_projector=False,
    name_mapping={...},              # key-name mapping for pretrained weights
)
Two Backbone Composition Modes#
FluxVLA supports two backbone composition modes, which are mutually exclusive:
Mode 1: Using a VLM Backbone (Recommended)
A VLM backbone is an end-to-end vision-language model (e.g., Eagle, PaliGemma) that has a built-in vision encoder and language model:
model = dict(
type='LlavaVLA',
vlm_backbone=dict(type='EagleBackbone', vlm_path='...'),
vla_head=dict(type='FlowMatchingHead', ...),
...)
Mode 2: Using Vision + LLM + Projector
Specify the vision encoder, language model, and projection layer separately:
model = dict(
type='OpenVLA',
vision_backbone=dict(type='DinoSigLIPViTBackbone', ...),
llm_backbone=dict(type='LLaMa2LLMBackbone', ...),
projector=dict(type='FusedMLPProjector', ...),
vla_head=dict(type='OpenVLAHead', ...),
...)
name_mapping for Pretrained Weight Mapping#
When the parameter names in the pretrained model differ from the module names in FluxVLA, use name_mapping to establish the mapping relationship:
name_mapping={
'vlm_backbone.vlm': 'backbone.eagle_model',
'vla_head': 'action_head'
}
The above mapping indicates that parameters under backbone.eagle_model.* in the pretrained checkpoint will be loaded into vlm_backbone.vlm.*, and parameters under action_head.* will be loaded into vla_head.*.
Complete Examples: Using Existing Configurations as References#
The following examples demonstrate three different model configuration styles, using configs/gr00t/gr00t_eagle_3b_aloha_full_finetune.py as a reference.
Example 1: GR00T Architecture (LlavaVLA + Eagle + FlowMatchingHead)#
This is the actual configuration from configs/gr00t/gr00t_eagle_3b_aloha_full_finetune.py:
model = dict(
type='LlavaVLA',
pretrained_name_or_path=
'./checkpoints/GR00T-N1.5-3B',
vlm_backbone=dict(
type='EagleBackbone',
vlm_path=
'fluxvla/models/third_party_models/eagle2_hg_model'),
vla_head=dict(
type='FlowMatchingHead',
state_dim=64,
hidden_size=1024,
input_embedding_dim=1536,
num_layers=1,
num_heads=4,
num_inference_timesteps=4,
num_steps=32,
traj_length=10,
action_dim=32,
ori_action_dim=14),
freeze_vlm_backbone=False,
name_mapping={
'vlm_backbone.vlm': 'backbone.eagle_model',
'vla_head': 'action_head'
},
freeze_projector=False)
Architecture Analysis:
| Component | Selection | Description |
|---|---|---|
| VLA | `LlavaVLA` | VLA with continuous action prediction supporting a VLM backbone |
| Backbone | `EagleBackbone` | Eagle2.5-VL multimodal model |
| Head | `FlowMatchingHead` | Flow Matching generative action head (DiT) |
| Training Strategy | Full parameter fine-tuning | `freeze_vlm_backbone=False`, `freeze_projector=False` |
Example 2: PI0 Architecture (PI0FlowMatching + PaliGemma)#
model = dict(
type='PI0FlowMatching',
vlm_backbone=dict(
type='PaliGemma',
vlm_backbone_id='paligemma_3b_pt_224',
vlm_config=dict(
text_config=dict(
num_hidden_layers=18,
hidden_size=2048,
intermediate_size=16384,
num_attention_heads=8,
num_key_value_heads=1))),
proj_width=1024,
n_action_steps=32,
state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
llm_expert=dict(
type='GemmaLLMBackbone',
llm_backbone_id='gemma-2b_causal',
llm_max_length=256),
freeze_vlm_backbone=True,
pretrained_name_or_path='./checkpoints/pi0_pretrained',
name_mapping={
'vlm_backbone.vlm': 'vla.vlm',
'llm_expert.llm': 'vla.language_model'
})
Architecture Analysis:
| Component | Selection | Description |
|---|---|---|
| VLA | `PI0FlowMatching` | PI0 Flow Matching architecture |
| Backbone | `PaliGemma` | PaliGemma 3B VLM |
| Expert | `GemmaLLMBackbone` | Additional Gemma LLM for action decoding |
| Projection | `LinearProjector` | Linear projection for states/actions |
| Training Strategy | Frozen VLM | `freeze_vlm_backbone=True` |
Example 3: OpenVLA Architecture (Separate Vision + LLM)#
model = dict(
type='OpenVLA',
arch_specifier='no-align+fused-gelu-mlp',
vision_backbone=dict(
type='DinoSigLIPViTBackbone',
vision_backbone_id='dinosiglip-vit-so-224px',
dino_config=dict(image_size=224, patch_size=14),
siglip_config=dict(image_size=224, patch_size=14)),
llm_backbone=dict(
type='LLaMa2LLMBackbone',
llm_backbone_id='llama2-7b-pure_causal',
llm_max_length=2048),
projector=dict(
type='FusedMLPProjector',
fused_vision_dim=2176,
llm_dim=4096),
vla_head=dict(
type='OpenVLAHead',
norm_stats=None,
vocab_size=32000),
freeze_vision_backbone=False,
freeze_llm_backbone=False,
freeze_projector=False)
Architecture Analysis:
| Component | Selection | Description |
|---|---|---|
| VLA | `OpenVLA` | Token-based action prediction VLA |
| Vision | `DinoSigLIPViTBackbone` | DINO + SigLIP fused vision encoder |
| Language | `LLaMa2LLMBackbone` | LLaMA-2 7B language model |
| Projection | `FusedMLPProjector` | Fused MLP projection |
| Head | `OpenVLAHead` | Token-discretized action head |
Core Checklist#
After adding a new model component, verify each item in the following checklist:
- [ ] The model file has been created in the correct directory (`vlas/`, `backbones/vlms/`, `heads/`, etc.)
- [ ] The appropriate registration decorator has been applied (`@VLAS.register_module()`, `@VLM_BACKBONES.register_module()`, `@HEADS.register_module()`, etc.)
- [ ] An import statement has been added in the corresponding `__init__.py`
- [ ] The VLA model implements `forward()` and `get_fsdp_wrapping_policy()`
- [ ] The VLM Backbone's `forward()` returns a `(features, attention_mask, extra)` tuple
- [ ] The VLA Head implements both `forward()` and `predict_action()` methods
- [ ] `forward()` returns a dictionary containing the `'loss'` key
- [ ] `predict_action()` returns an action tensor
- [ ] The `type` in the configuration file exactly matches the registered class name
- [ ] The parameter names in the configuration file match the parameter names in the `__init__` method
- [ ] If using pretrained weights, `name_mapping` correctly maps the parameter names
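As a quick way to verify the registration items above, you can query the registries directly. This snippet assumes the FluxVLA registries expose an mmengine-style get() method; adjust it to the actual Registry API in fluxvla.engines.utils.root if it differs:

# Importing fluxvla.models triggers the __init__.py imports and thus registration.
import fluxvla.models  # noqa: F401
from fluxvla.engines.utils.root import HEADS, VLAS, VLM_BACKBONES

assert VLAS.get('MyVLA') is not None
assert VLM_BACKBONES.get('MyVLMBackbone') is not None
assert HEADS.get('MyActionHead') is not None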
Frequently Asked Questions#
Q: How do I decide whether to inherit from BaseVLA or an existing VLA?#
- If your model is fundamentally different from existing VLA architectures (e.g., a new training paradigm), inherit from `BaseVLA`
- If your model is a variant of an existing architecture (e.g., swapping the backbone or head), simply switch submodules via the configuration file; there is no need to create a new VLA class
- If you need to modify the `forward` logic while reusing most of the code, inherit from the most similar VLA (e.g., `LlavaVLA`); see the sketch below
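For the third case, a minimal sketch of subclassing an existing VLA might look like this (it assumes LlavaVLA keeps the dict-returning forward() convention; the import path follows the directory layout shown earlier, and the auxiliary-loss line is purely illustrative):

from fluxvla.engines.utils.root import VLAS
from fluxvla.models.vlas.llava_vla import LlavaVLA


@VLAS.register_module()
class MyLlavaVariant(LlavaVLA):
    """Reuses LlavaVLA's submodule construction and predict_action(),
    overriding only the training forward pass."""

    def forward(self, *args, **kwargs):
        output = super().forward(*args, **kwargs)
        # e.g. add a hypothetical auxiliary loss term here:
        # output['loss'] = output['loss'] + 0.1 * aux_loss
        return output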
Q: What configuration changes are needed after adding a new Head?#
You only need to modify the type and corresponding parameters in the model.vla_head dictionary:
model = dict(
    type='LlavaVLA',          # VLA unchanged
    vlm_backbone=dict(...),   # backbone unchanged
    vla_head=dict(
        type='MyActionHead',  # ← swap in your Head
        hidden_size=1024,
        state_dim=64,
        action_dim=32,
        input_embedding_dim=1536,
        traj_length=10,
        ori_action_dim=14),
    ...)
Q: What is inference_model?#
inference_model is an optional inference-specific model configuration. It is used when a different model setup is required during inference (e.g., replacing FlowMatchingHead with FlowMatchingInferenceHead to enable CUDA Graph acceleration):
model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingHead', ...), ...)
inference_model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingInferenceHead', ...), ...)
If inference_model is not defined, the model configuration is used by default during inference.
Q: How do I ensure my Head is compatible with FSDP training?#
If your Head contains a large number of parameters, you should implement the get_fsdp_wrapping_policy() method to specify which submodules should be individually sharded by FSDP:
def get_fsdp_wrapping_policy(self):
from functools import partial
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
return partial(
transformer_auto_wrap_policy,
transformer_layer_cls={MyTransformerBlock})
If the Head has a small number of parameters, returning None is sufficient.
Q: What is the mapping direction of name_mapping?#
The format of name_mapping is {module name in FluxVLA: module name in pretrained checkpoint}:
name_mapping={
    'vlm_backbone.vlm': 'backbone.eagle_model',
    # Meaning: FluxVLA's vlm_backbone.vlm.xxx ← the checkpoint's backbone.eagle_model.xxx
}
During weight loading, the framework renames parameters with the backbone.eagle_model. prefix in the checkpoint to vlm_backbone.vlm. before loading them.
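The following sketch illustrates that direction. It is not the framework's actual loading code; it merely shows how checkpoint keys get remapped onto FluxVLA module names under the convention described above:

def remap_state_dict(state_dict, name_mapping):
    """Rename checkpoint keys according to {fluxvla_name: checkpoint_name}."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for fluxvla_name, ckpt_name in name_mapping.items():
            if key.startswith(ckpt_name + '.'):
                new_key = fluxvla_name + key[len(ckpt_name):]
                break
        remapped[new_key] = value
    return remapped

# 'backbone.eagle_model.layers.0.weight' -> 'vlm_backbone.vlm.layers.0.weight'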