添加自定义模块#
概述#
FluxVLA 采用注册表(Registry)机制管理各类模块,支持灵活地添加自定义实现。可扩展的模块类型包括:
注册表 |
说明 |
代码位置 |
|---|---|---|
|
动作预测头(如 Diffusion Policy、VAE 等) |
|
|
VLA 模型 |
|
|
语言模型骨干网络 |
|
|
视觉骨干网络 |
|
|
视觉语言模型骨干网络 |
|
|
特征投影器 |
|
|
数据集 |
|
|
数据变换 |
|
|
数据整理器 |
|
|
分词器 |
|
|
训练/推理运行器 |
|
|
机器人算子 |
|
|
评估指标 |
|
|
处理器 |
|
所有模块通过对应的注册表进行管理,可在配置文件中通过 type 字段引用。各类模块的添加流程一致:创建文件 → 注册模块 → 导入 → 配置引用。
本文以 VLA Head(动作预测头)为例,演示如何添加自定义模块。
Head 相关代码位于 fluxvla/models/heads/:
fluxvla/models/heads/
├── flow_matching_head.py # FlowMatchingHead (DiT)
├── flow_matching_inference_head.py # FlowMatchingInferenceHead (CUDA Graph)
├── llava_action_head.py # LlavaActionHead (Transformer)
└── openvla_head.py # OpenVLAHead (Token-based)
现有 Head 的设计参考:
Head 类型 |
动作预测方式 |
特点 |
|---|---|---|
|
Flow Matching (DiT) |
生成式,多步去噪,适合复杂动作分布 |
|
Transformer 解码 |
自回归预测,支持可变长度轨迹 |
|
Token 离散化 |
基于 VLM 的 token 预测,无需额外网络 |
添加新的 VLA Head#
第一步:创建 Head 文件#
在 fluxvla/models/heads/ 下创建新文件,例如 my_action_head.py:
from typing import Callable, Dict, Optional
import torch
import torch.nn as nn
from fluxvla.engines.utils.root import HEADS
@HEADS.register_module()
class MyActionHead(nn.Module):
"""自定义动作预测头。"""
def __init__(self,
hidden_size: int = 1024,
state_dim: int = 64,
action_dim: int = 32,
input_embedding_dim: int = 1536,
traj_length: int = 10,
ori_action_dim: Optional[int] = None):
super().__init__()
self.hidden_size = hidden_size
self.state_dim = state_dim
self.action_dim = action_dim
self.traj_length = traj_length
self.ori_action_dim = ori_action_dim or action_dim
# 状态编码器:将机器人状态映射到隐空间
self.state_encoder = nn.Sequential(
nn.Linear(state_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size))
# 动作解码器:从特征预测动作序列
self.action_decoder = nn.Sequential(
nn.Linear(input_embedding_dim + hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, traj_length * action_dim))
def forward(self,
input_features: torch.Tensor,
states: torch.Tensor,
attention_mask: torch.Tensor,
actions: torch.Tensor,
action_masks: torch.Tensor,
**kwargs) -> Dict:
"""前向传播(训练时调用)。
Args:
input_features: VLM 输出特征, shape=(B, seq_len, embed_dim)
states: 机器人状态, shape=(B, state_dim)
attention_mask: 注意力掩码, shape=(B, seq_len)
actions: 目标动作, shape=(B, traj_length, action_dim)
action_masks: 动作掩码, shape=(B, traj_length)
Returns:
dict: {'pred_actions': ..., 'loss': ...}
"""
B = input_features.shape[0]
# 1. 编码状态
state_emb = self.state_encoder(states) # (B, hidden_size)
# 2. 池化 VLM 特征(取最后一个有效 token)
seq_lengths = attention_mask.sum(dim=1).long() - 1
pooled = input_features[range(B), seq_lengths] # (B, embed_dim)
# 3. 拼接并预测动作
combined = torch.cat([pooled, state_emb], dim=-1)
pred_flat = self.action_decoder(combined) # (B, traj_len * action_dim)
pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
# 4. 计算损失
loss = nn.functional.mse_loss(
pred_actions[action_masks.bool()],
actions[action_masks.bool()])
return {'pred_actions': pred_actions, 'loss': loss}
def predict_action(self,
input_features: torch.Tensor,
states: torch.Tensor,
attention_mask: torch.Tensor,
**kwargs) -> torch.Tensor:
"""动作预测(推理时调用)。
Returns:
torch.Tensor: 预测动作序列, shape=(B, traj_length, action_dim)
"""
B = input_features.shape[0]
state_emb = self.state_encoder(states)
seq_lengths = attention_mask.sum(dim=1).long() - 1
pooled = input_features[range(B), seq_lengths]
combined = torch.cat([pooled, state_emb], dim=-1)
pred_flat = self.action_decoder(combined)
pred_actions = pred_flat.view(B, self.traj_length, self.action_dim)
return pred_actions
def get_fsdp_wrapping_policy(self) -> Callable:
"""返回 FSDP 包装策略。"""
return None # 简单模型无需特殊的 FSDP 分片策略
第二步:理解 Head 接口约定#
VLA Head 需要实现两个核心方法:
方法 |
场景 |
输入 |
输出 |
|---|---|---|---|
|
训练 |
特征 + 状态 + 目标动作 |
|
|
推理 |
特征 + 状态 |
|
forward() 输入参数:
参数 |
类型 |
说明 |
|---|---|---|
|
|
VLM 输出特征, shape=(B, seq_len, embed_dim) |
|
|
机器人状态, shape=(B, state_dim) |
|
|
注意力掩码, shape=(B, seq_len) |
|
|
目标动作, shape=(B, traj_length, action_dim) |
|
|
动作掩码, shape=(B, traj_length) |
forward() 返回值:必须返回一个包含 'loss' 键的字典,例如 {'pred_actions': Tensor, 'loss': Tensor}。
predict_action() 返回值:返回预测的动作张量, shape=(B, traj_length, action_dim)。
第三步:注册并导入#
在 fluxvla/models/heads/__init__.py 中添加导入:
from .flow_matching_head import FlowMatchingHead # noqa: F401, F403
from .flow_matching_inference_head import FlowMatchingInferenceHead # noqa: F401, F403
from .llava_action_head import LlavaActionHead # noqa: F401, F403
from .my_action_head import MyActionHead # noqa: F401, F403 ← 新增
from .openvla_head import OpenVLAHead # noqa: F401, F403
第四步:在配置文件中使用#
只需修改 model.vla_head 字典中的 type 和对应参数:
model = dict(
type='LlavaVLA', # VLA 不变
vlm_backbone=dict(...), # backbone 不变
vla_head=dict(
type='MyActionHead', # ← 换成你的 Head
hidden_size=1024,
state_dim=64,
action_dim=32,
input_embedding_dim=1536,
traj_length=10,
ori_action_dim=14),
...)
核心检查清单#
添加新的 VLA Head 后,请按以下清单逐项确认:
Head 文件已创建在
fluxvla/models/heads/目录下使用了
@HEADS.register_module()装饰器注册在
fluxvla/models/heads/__init__.py中添加了 import 语句实现了
forward()方法,返回包含'loss'键的字典实现了
predict_action()方法,返回动作张量配置文件中的
type与注册类名完全一致配置文件中的参数名与
__init__方法的参数名一致
常见问题#
Q:添加新 Head 后模型配置需要改什么?#
只需要修改 model.vla_head 字典中的 type 和对应参数,VLA 模型和 Backbone 不需要改动。
Q:inference_model 是什么?#
inference_model 是可选的推理专用模型配置。当推理时需要不同的模型设置(例如使用 FlowMatchingInferenceHead 替代 FlowMatchingHead 以启用 CUDA Graph 加速)时使用:
model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingHead', ...), ...)
inference_model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingInferenceHead', ...), ...)
如果未定义 inference_model,推理时默认使用 model 的配置。
Q:如何确保我的 Head 与 FSDP 训练兼容?#
如果你的 Head 包含大量参数,需要实现 get_fsdp_wrapping_policy() 方法,指定哪些子模块应被 FSDP 单独分片:
def get_fsdp_wrapping_policy(self):
from functools import partial
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
return partial(
transformer_auto_wrap_policy,
transformer_layer_cls={MyTransformerBlock})
如果 Head 参数量较小,返回 None 即可。