添加自定义模型#

概述#

在 FluxVLA 中,一个完整的 VLA 模型(Vision-Language-Action Model)由多个子模块组合而成:

子模块

注册表

作用

示例

VLA

VLAS

顶层模型,编排所有子模块

LlavaVLAOpenVLAPI0FlowMatching

VLM Backbone

VLM_BACKBONES

视觉-语言多模态理解

EagleBackbonePaliGemmaQWen2_5VL

Vision Backbone

VISION_BACKBONES

纯视觉特征提取

SigLIPViTBackboneDinoSigLIPViTBackbone

LLM Backbone

LLM_BACKBONES

纯语言模型

LLaMa2LLMBackboneGemmaLLMBackbone

Projector

PROJECTORS

特征空间映射(视觉 → 语言)

MLPProjectorLinearProjectorFusedMLPProjector

VLA Head

HEADS

动作预测头

FlowMatchingHeadLlavaActionHeadOpenVLAHead

本教程将指导你如何为 FluxVLA 添加自定义的模型组件——无论是一个全新的 VLA 架构、一种新的 VLM Backbone,还是一种新的动作预测头。


架构概览#

模型继承关系#

BaseVLA (ABC)
├── OpenVLA                  # 基于 token 的动作预测
│   └── LlavaVLA             # 支持 VLM backbone 的连续动作预测
└── PI0FlowMatching          # Flow Matching 生成式动作预测
    └── PI05FlowMatching     # PI0.5 改进版

目录结构#

所有模型相关代码位于 fluxvla/models/

fluxvla/models/
├── __init__.py              # 导出 backbones、heads、projectors、vlas
├── vlas/
│   ├── __init__.py          # 导入所有 VLA(触发注册)
│   ├── base_vla.py          # VLA 抽象基类
│   ├── open_vla.py          # OpenVLA
│   ├── llava_vla.py         # LlavaVLA(继承 OpenVLA)
│   ├── pi0_flowmatching.py  # PI0FlowMatching
│   └── pi05_flowmatching.py # PI05FlowMatching(继承 PI0)
├── backbones/
│   ├── vlms/                # VLM backbones
│   │   ├── eagle.py         # EagleBackbone
│   │   ├── paligemma.py     # PaliGemma
│   │   └── qwen2_5_vl.py   # QWen2_5VL
│   ├── visions/             # 视觉 backbones
│   │   ├── siglip_vit.py    # SigLIPViTBackbone
│   │   └── dinosiglip_vit.py # DinoSigLIPViTBackbone
│   └── llms/                # LLM backbones
│       ├── llama2.py        # LLaMa2LLMBackbone
│       ├── gemma.py         # GemmaLLMBackbone
│       └── qwen2.py         # Qwen2LLMBackbone
├── heads/
│   ├── flow_matching_head.py         # FlowMatchingHead (DiT)
│   ├── flow_matching_inference_head.py # FlowMatchingInferenceHead (CUDA Graph)
│   ├── llava_action_head.py          # LlavaActionHead (Transformer)
│   └── openvla_head.py              # OpenVLAHead (Token-based)
├── projectors/
│   ├── mlp_projector.py     # MLPProjector
│   ├── linear_projector.py  # LinearProjector
│   └── fused_projector.py   # FusedMLPProjector
└── blocks/
    └── cross_attention_dit.py # DiT Transformer block

注册与构建机制#

FluxVLA 使用统一的 Registry 模式管理所有模型组件。每个组件通过装饰器注册,配置文件通过 type 字段引用。

注册

from fluxvla.engines.utils.root import VLAS

@VLAS.register_module()
class MyVLA(BaseVLA):
    ...

构建(自动在 BaseVLA.__init__ 中完成):

# BaseVLA.__init__ 中的子模块构建逻辑
if vlm_backbone is not None:
    self.vlm_backbone = build_vlm_backbone_from_cfg(vlm_backbone)
if vision_backbone is not None:
    self.vision_backbone = build_vision_backbone_from_cfg(vision_backbone)
if llm_backbone is not None:
    self.llm_backbone = build_llm_backbone_from_cfg(llm_backbone)
if projector is not None:
    self.projector = build_projector_from_cfg(projector)
if vla_head is not None:
    self.vla_head = build_head_from_cfg(vla_head)

训练时的模型构建BaseTrainRunner.__init__):

self.vla = build_vla_from_cfg(cfg.model)   # 读取 config 中的 model 字典

因此,你只需要在配置文件中填写 type 和对应参数,框架会自动通过注册表查找并实例化所有子模块。


添加新的 VLA 模型#

当你需要实现一种全新的 VLA 架构时(例如不同的训练范式、不同的模态融合方式),需要添加一个新的顶层 VLA 模型。

第一步:创建模型文件#

fluxvla/models/vlas/ 下创建新文件,例如 my_vla.py

from typing import Dict, Optional

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from fluxvla.engines.utils.root import VLAS
from .base_vla import BaseVLA


@VLAS.register_module()
class MyVLA(BaseVLA):
    """自定义 VLA 模型示例。"""

    def __init__(self,
                 my_custom_param: int = 256,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.my_custom_param = my_custom_param
        # BaseVLA 已经根据 config 自动构建了:
        #   self.vlm_backbone  (如果 config 中指定了 vlm_backbone)
        #   self.vision_backbone (如果 config 中指定了 vision_backbone)
        #   self.llm_backbone  (如果 config 中指定了 llm_backbone)
        #   self.projector     (如果 config 中指定了 projector)
        #   self.vla_head      (如果 config 中指定了 vla_head)

    def forward(self,
                images: torch.Tensor,
                lang_tokens: torch.LongTensor,
                lang_masks: torch.Tensor,
                img_masks: torch.Tensor,
                states: torch.Tensor,
                actions: torch.Tensor,
                action_masks: torch.Tensor,
                **kwargs) -> Dict:
        """前向传播(训练时调用)。

        Args:
            images: 图像张量, shape=(B, N_cam, C, H, W)
            lang_tokens: 语言 token, shape=(B, seq_len)
            lang_masks: 语言掩码, shape=(B, seq_len)
            img_masks: 图像掩码, shape=(B, N_cam)
            states: 机器人状态, shape=(B, state_dim)
            actions: 目标动作序列, shape=(B, traj_len, action_dim)
            action_masks: 动作掩码, shape=(B, traj_len)

        Returns:
            dict: 包含 'loss' 键的字典
        """
        # 1. 提取视觉-语言特征
        features, attention_mask, _ = self.vlm_backbone(
            images, lang_tokens, img_masks, lang_masks)

        # 2. 通过动作头预测动作并计算损失
        output = self.vla_head(
            input_features=features,
            states=states,
            attention_mask=attention_mask,
            actions=actions,
            action_masks=action_masks)

        return output  # {'pred_actions': ..., 'loss': ...}

    def predict_action(self,
                       images: torch.Tensor,
                       lang_tokens: torch.LongTensor,
                       lang_masks: torch.Tensor,
                       img_masks: torch.Tensor,
                       states: torch.Tensor,
                       **kwargs) -> torch.Tensor:
        """动作预测(推理时调用)。

        Returns:
            torch.Tensor: 预测的动作序列
        """
        features, attention_mask, _ = self.vlm_backbone(
            images, lang_tokens, img_masks, lang_masks)

        actions = self.vla_head.predict_action(
            input_features=features,
            states=states,
            attention_mask=attention_mask)

        return actions

    def get_fsdp_wrapping_policy(self):
        """返回 FSDP 包装策略,指定哪些子模块应单独分片。"""
        from functools import partial
        from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

        def should_wrap(module, recurse, *args, **kwargs):
            if recurse:
                return True
            # 指定需要单独分片的模块类型
            return isinstance(module, (
                nn.TransformerEncoderLayer,
                nn.TransformerDecoderLayer,
            ))

        return partial(lambda_auto_wrap_policy, lambda_fn=should_wrap)

第二步:理解 BaseVLA 接口#

BaseVLA 是所有 VLA 模型的抽象基类,提供了通用的初始化逻辑和预训练权重加载功能。子类必须实现以下抽象方法:

方法

作用

调用时机

forward(...)

前向传播,计算训练损失

训练循环的每个 step

get_fsdp_wrapping_policy()

返回 FSDP 分片策略

使用 FSDPTrainRunner

基类已经提供的功能(子类通常不需要重写):

  • 子模块自动构建:根据 config 自动创建 backbone、projector、head

  • 预训练权重加载:通过 pretrained_name_or_pathname_mapping 加载预训练模型

  • 模块冻结:通过 freeze_vlm_backbonefreeze_projector 等控制参数冻结

  • from_pretrained():从 checkpoint 加载权重

第三步:注册并导入#

fluxvla/models/vlas/__init__.py 中添加导入:

from .llava_vla import LlavaVLA  # noqa: F401, F403
from .my_vla import MyVLA  # noqa: F401, F403               ← 新增
from .open_vla import OpenVLA  # noqa: F401, F403
from .pi0_flowmatching import PI0FlowMatching  # noqa: F401, F403
from .pi05_flowmatching import PI05FlowMatching  # noqa: F401, F403

第四步:在配置文件中使用#

model = dict(
    type='MyVLA',                    # ← 你注册的 VLA 类名
    my_custom_param=256,             # ← 自定义参数
    pretrained_name_or_path='./checkpoints/my_pretrained',
    vlm_backbone=dict(
        type='EagleBackbone',
        vlm_path='fluxvla/models/third_party_models/eagle2_hg_model'),
    vla_head=dict(
        type='FlowMatchingHead',
        state_dim=64,
        action_dim=32,
        ori_action_dim=14,
        ...),
    freeze_vlm_backbone=False,
    freeze_projector=False)

添加新的 VLM Backbone#

当你需要接入一种新的视觉-语言预训练模型(如新版 InternVL、LLaVA-Next 等)作为 backbone 时,需要添加一个新的 VLM Backbone。

第一步:创建 Backbone 文件#

fluxvla/models/backbones/vlms/ 下创建新文件,例如 my_vlm.py

from typing import Callable, Dict, Optional, Sequence, Type

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from fluxvla.engines.utils.root import VLM_BACKBONES


@VLM_BACKBONES.register_module()
class MyVLMBackbone(nn.Module):
    """自定义 VLM Backbone。"""

    def __init__(self,
                 vlm_path: str,
                 vlm_config: Optional[Dict] = None,
                 select_layer: int = -1,
                 dtype: str = 'float32'):
        super().__init__()
        # 加载预训练 VLM 模型
        # 例如使用 transformers 的 AutoModel:
        from transformers import AutoModel, AutoConfig
        if vlm_config is not None:
            config = AutoConfig.from_pretrained(vlm_path, **vlm_config)
        else:
            config = AutoConfig.from_pretrained(vlm_path)
        self.vlm = AutoModel.from_pretrained(vlm_path, config=config)
        self.select_layer = select_layer
        self.config = config

    def forward(self,
                images: torch.Tensor,
                lang_tokens: torch.LongTensor,
                img_masks: torch.Tensor,
                lang_masks: torch.Tensor
                ) -> tuple:
        """前向传播。

        Args:
            images: 图像张量, shape=(B, N_cam, C, H, W)
            lang_tokens: 语言 token, shape=(B, seq_len)
            img_masks: 图像掩码, shape=(B, N_cam)
            lang_masks: 语言掩码, shape=(B, seq_len)

        Returns:
            tuple: (features, attention_mask, None)
                - features: VLM 输出特征, shape=(B, total_seq_len, hidden_dim)
                - attention_mask: 注意力掩码
        """
        # 实现你的多模态特征提取逻辑
        outputs = self.vlm(
            input_ids=lang_tokens,
            pixel_values=images,
            attention_mask=lang_masks,
        )
        features = outputs.last_hidden_state
        return features, lang_masks, None

    def enable_gradient_checkpointing(self):
        """启用梯度检查点以节省显存。"""
        self.vlm.gradient_checkpointing_enable()

    def get_fsdp_wrapping_policy(self) -> Callable:
        """返回 FSDP 包装策略。"""
        from functools import partial
        from torch.distributed.fsdp.wrap import (
            lambda_auto_wrap_policy, transformer_auto_wrap_policy)

        # 获取 VLM 中的 Transformer 层类型
        layer_cls = self._get_transformer_layer_cls()
        return partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={layer_cls})

    def _get_transformer_layer_cls(self) -> Type[nn.Module]:
        """返回 Transformer 层的类型,用于 FSDP 包装。"""
        # 根据你的 VLM 架构返回对应的层类型
        raise NotImplementedError

第二步:理解输入输出约定#

VLM Backbone 的 forward 方法需要遵循以下约定,以便与上层 VLA 模型兼容:

输入参数

参数

类型

说明

images

torch.Tensor

图像张量,shape 因实现而异

lang_tokens

torch.LongTensor

语言 token IDs

img_masks

torch.Tensor

图像有效位掩码

lang_masks

torch.Tensor

语言有效位掩码

返回值

return (features, attention_mask, extra)

返回值

说明

features

多模态融合特征, shape=(B, seq_len, hidden_dim)

attention_mask

注意力掩码, shape=(B, seq_len)

extra

额外信息(一般为 None

第三步:注册并导入#

fluxvla/models/backbones/vlms/__init__.py 中添加导入:

from .eagle import EagleBackbone, EagleInferenceBackbone  # noqa: F401, F403
from .my_vlm import MyVLMBackbone  # noqa: F401, F403        ← 新增
from .paligemma import PaliGemma  # noqa: F401, F403
from .qwen2_5_vl import QWen2_5VL  # noqa: F401, F403

添加新的 VLA Head#

当你需要使用一种新的动作预测方法时(例如 Diffusion Policy、VAE、VQ-VAE 等),需要添加一个新的 VLA Head。

第二步:理解 Head 接口约定#

VLA Head 需要实现两个核心方法:

方法

场景

输入

输出

forward(...)

训练

特征 + 状态 + 目标动作

{'pred_actions': Tensor, 'loss': Tensor}

predict_action(...)

推理

特征 + 状态

Tensor (预测的动作序列)

现有 Head 的设计参考:

Head 类型

动作预测方式

特点

FlowMatchingHead

Flow Matching (DiT)

生成式,多步去噪,适合复杂动作分布

LlavaActionHead

Transformer 解码

自回归预测,支持可变长度轨迹

OpenVLAHead

Token 离散化

基于 VLM 的 token 预测,无需额外网络

第三步:注册并导入#

fluxvla/models/heads/__init__.py 中添加导入:

from .flow_matching_head import FlowMatchingHead  # noqa: F401, F403
from .flow_matching_inference_head import FlowMatchingInferenceHead  # noqa: F401, F403
from .llava_action_head import LlavaActionHead  # noqa: F401, F403
from .my_action_head import MyActionHead  # noqa: F401, F403   ← 新增
from .openvla_head import OpenVLAHead  # noqa: F401, F403

添加新的 Projector#

当你的视觉 backbone 输出维度与语言模型输入维度不匹配时,需要一个 Projector 进行特征空间映射。

from fluxvla.engines.utils.root import PROJECTORS

@PROJECTORS.register_module()
class MyProjector(nn.Module):
    """自定义特征投影器。"""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.projector = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim))

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        return self.projector(input_features)

fluxvla/models/projectors/__init__.py 中添加导入后即可通过配置使用:

model = dict(
    type='OpenVLA',
    projector=dict(
        type='MyProjector',
        in_dim=2176,
        out_dim=4096),
    ...)

在配置文件中组合模型#

FluxVLA 的模型配置采用嵌套字典的方式,顶层指定 VLA 类型,内部嵌套子模块配置。BaseVLA 会自动根据配置构建所有子模块。

配置结构#

model = dict(
    type='...',                     # VLA 类名 → VLAS 注册表
    pretrained_name_or_path='...',  # 预训练权重路径
    vlm_backbone=dict(              # VLM backbone → VLM_BACKBONES 注册表
        type='...',
        ...),
    # 或者使用 vision_backbone + llm_backbone + projector 的组合:
    # vision_backbone=dict(type='...', ...),   → VISION_BACKBONES
    # llm_backbone=dict(type='...', ...),      → LLM_BACKBONES
    # projector=dict(type='...', ...),         → PROJECTORS
    vla_head=dict(                  # VLA Head → HEADS 注册表
        type='...',
        ...),
    freeze_vlm_backbone=False,
    freeze_projector=False,
    name_mapping={...},             # 预训练权重的键名映射
)

两种 Backbone 组合模式#

FluxVLA 支持两种 backbone 组合方式,不可混用:

模式一:使用 VLM Backbone(推荐)

VLM backbone 是一个端到端的视觉-语言模型(如 Eagle、PaliGemma),已经内置了视觉编码器和语言模型:

model = dict(
    type='LlavaVLA',
    vlm_backbone=dict(type='EagleBackbone', vlm_path='...'),
    vla_head=dict(type='FlowMatchingHead', ...),
    ...)

模式二:使用 Vision + LLM + Projector

分别指定视觉编码器、语言模型和投影层:

model = dict(
    type='OpenVLA',
    vision_backbone=dict(type='DinoSigLIPViTBackbone', ...),
    llm_backbone=dict(type='LLaMa2LLMBackbone', ...),
    projector=dict(type='FusedMLPProjector', ...),
    vla_head=dict(type='OpenVLAHead', ...),
    ...)

name_mapping 预训练权重映射#

当预训练模型的参数名称与 FluxVLA 中的模块名称不同时,使用 name_mapping 建立映射关系:

name_mapping={
    'vlm_backbone.vlm': 'backbone.eagle_model',
    'vla_head': 'action_head'
}

上述映射表示:预训练 checkpoint 中 backbone.eagle_model.* 的参数会被加载到 vlm_backbone.vlm.*action_head.* 的参数会被加载到 vla_head.*


完整示例:以现有配置为参考#

下面以 configs/gr00t/gr00t_eagle_3b_aloha_full_finetune.py 为参考,展示三种不同的模型配置风格。

示例 1:GR00T 架构(LlavaVLA + Eagle + FlowMatchingHead)#

这是 configs/gr00t/gr00t_eagle_3b_aloha_full_finetune.py 中的实际配置:

model = dict(
    type='LlavaVLA',
    pretrained_name_or_path=
    './checkpoints/GR00T-N1.5-3B',
    vlm_backbone=dict(
        type='EagleBackbone',
        vlm_path=
        'fluxvla/models/third_party_models/eagle2_hg_model'),
    vla_head=dict(
        type='FlowMatchingHead',
        state_dim=64,
        hidden_size=1024,
        input_embedding_dim=1536,
        num_layers=1,
        num_heads=4,
        num_inference_timesteps=4,
        num_steps=32,
        traj_length=10,
        action_dim=32,
        ori_action_dim=14),
    freeze_vlm_backbone=False,
    name_mapping={
        'vlm_backbone.vlm': 'backbone.eagle_model',
        'vla_head': 'action_head'
    },
    freeze_projector=False)

架构解读

组件

选择

说明

VLA

LlavaVLA

支持 VLM backbone 的连续动作预测 VLA

Backbone

EagleBackbone

Eagle2.5-VL 多模态模型

Head

FlowMatchingHead

Flow Matching 生成式动作头 (DiT)

训练策略

全参数微调

freeze_vlm_backbone=False

示例 2:PI0 架构(PI0FlowMatching + PaliGemma)#

model = dict(
    type='PI0FlowMatching',
    vlm_backbone=dict(
        type='PaliGemma',
        vlm_backbone_id='paligemma_3b_pt_224',
        vlm_config=dict(
            text_config=dict(
                num_hidden_layers=18,
                hidden_size=2048,
                intermediate_size=16384,
                num_attention_heads=8,
                num_key_value_heads=1))),
    proj_width=1024,
    n_action_steps=32,
    state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
    llm_expert=dict(
        type='GemmaLLMBackbone',
        llm_backbone_id='gemma-2b_causal',
        llm_max_length=256),
    freeze_vlm_backbone=True,
    pretrained_name_or_path='./checkpoints/pi0_pretrained',
    name_mapping={
        'vlm_backbone.vlm': 'vla.vlm',
        'llm_expert.llm': 'vla.language_model'
    })

架构解读

组件

选择

说明

VLA

PI0FlowMatching

PI0 Flow Matching 架构

Backbone

PaliGemma

PaliGemma 3B VLM

Expert

GemmaLLMBackbone

额外的 Gemma LLM 用于动作解码

投影

LinearProjector

状态/动作的线性投影

训练策略

冻结 VLM

freeze_vlm_backbone=True

示例 3:OpenVLA 架构(Vision + LLM 分离)#

model = dict(
    type='OpenVLA',
    arch_specifier='no-align+fused-gelu-mlp',
    vision_backbone=dict(
        type='DinoSigLIPViTBackbone',
        vision_backbone_id='dinosiglip-vit-so-224px',
        dino_config=dict(image_size=224, patch_size=14),
        siglip_config=dict(image_size=224, patch_size=14)),
    llm_backbone=dict(
        type='LLaMa2LLMBackbone',
        llm_backbone_id='llama2-7b-pure_causal',
        llm_max_length=2048),
    projector=dict(
        type='FusedMLPProjector',
        fused_vision_dim=2176,
        llm_dim=4096),
    vla_head=dict(
        type='OpenVLAHead',
        norm_stats=None,
        vocab_size=32000),
    freeze_vision_backbone=False,
    freeze_llm_backbone=False,
    freeze_projector=False)

架构解读

组件

选择

说明

VLA

OpenVLA

基于 token 的动作预测 VLA

视觉

DinoSigLIPViTBackbone

DINO + SigLIP 融合视觉编码器

语言

LLaMa2LLMBackbone

LLaMA-2 7B 语言模型

投影

FusedMLPProjector

融合 MLP 投影

Head

OpenVLAHead

Token 离散化动作头


核心检查清单#

添加新的模型组件后,请按以下清单逐项确认:

  • [ ] 模型文件已创建在正确的目录下(vlas/backbones/vlms/heads/ 等)

  • [ ] 使用了对应的注册装饰器(@VLAS.register_module()@VLM_BACKBONES.register_module()@HEADS.register_module() 等)

  • [ ] 在对应的 __init__.py 中添加了 import 语句

  • [ ] VLA 模型实现了 forward()get_fsdp_wrapping_policy()

  • [ ] VLM Backbone 的 forward() 返回 (features, attention_mask, extra) 三元组

  • [ ] VLA Head 实现了 forward()predict_action() 方法

  • [ ] forward() 返回包含 'loss' 键的字典

  • [ ] predict_action() 返回动作张量

  • [ ] 配置文件中的 type 与注册类名完全一致

  • [ ] 配置文件中的参数名与 __init__ 方法的参数名一致

  • [ ] 如果使用预训练权重,name_mapping 正确映射了参数名


常见问题#

Q:如何选择继承 BaseVLA 还是已有的 VLA?#

  • 如果你的模型与现有 VLA 架构完全不同(如新的训练范式),继承 BaseVLA

  • 如果你的模型是现有架构的变体(如换个 backbone 或 head),直接通过配置文件切换子模块即可,无需创建新的 VLA 类

  • 如果你需要修改 forward 逻辑但复用大部分代码,继承最相近的 VLA(如 LlavaVLA

Q:添加新 Head 后模型配置需要改什么?#

只需要修改 model.vla_head 字典中的 type 和对应参数:

model = dict(
    type='LlavaVLA',                    # VLA 不变
    vlm_backbone=dict(...),             # backbone 不变
    vla_head=dict(
        type='MyActionHead',            # ← 换成你的 Head
        hidden_size=1024,
        state_dim=64,
        action_dim=32,
        input_embedding_dim=1536,
        traj_length=10,
        ori_action_dim=14),
    ...)

Q:inference_model 是什么?#

inference_model 是可选的推理专用模型配置。当推理时需要不同的模型设置(例如使用 FlowMatchingInferenceHead 替代 FlowMatchingHead 以启用 CUDA Graph 加速)时使用:

model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingHead', ...), ...)
inference_model = dict(type='LlavaVLA', vla_head=dict(type='FlowMatchingInferenceHead', ...), ...)

如果未定义 inference_model,推理时默认使用 model 的配置。

Q:如何确保我的 Head 与 FSDP 训练兼容?#

如果你的 Head 包含大量参数,需要实现 get_fsdp_wrapping_policy() 方法,指定哪些子模块应被 FSDP 单独分片:

def get_fsdp_wrapping_policy(self):
    from functools import partial
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    return partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={MyTransformerBlock})

如果 Head 参数量较小,返回 None 即可。

Q:name_mapping 的映射方向是怎样的?#

name_mapping 的格式为 {FluxVLA 中的模块名: 预训练 checkpoint 中的模块名}

name_mapping={
    'vlm_backbone.vlm': 'backbone.eagle_model',
    # 含义:FluxVLA 的 vlm_backbone.vlm.xxx ← checkpoint 的 backbone.eagle_model.xxx
}

加载权重时,框架会将 checkpoint 中 backbone.eagle_model. 前缀的参数重命名为 vlm_backbone.vlm. 后再加载。