GR00T-N1.5 模型训练#

本教程提供了一份完整指南,介绍如何使用 FluxVLA 框架对 GR00T-N1.5 模型进行微调训练。内容涵盖两种典型场景:

  1. LIBERO 仿真数据训练 — 使用开箱即用的 LIBERO 基准数据集进行快速验证, LIBERO 仿真数据训练

  2. 私有真机数据训练 — 使用自己采集的机器人数据进行定制化微调, (参见 真机数据准备)。

模型架构#

GR00T-N1.5 在 FluxVLA 中由以下三个核心模块组成:

模块

类型

作用

VLA 容器

LlavaVLA

模型外壳,负责加载预训练权重、协调各子模块

视觉语言骨干

EagleBackbone

Eagle2 多模态模型,处理图像特征提取与语言理解

动作头

FlowMatchingHead

基于 Flow Matching 的动作生成器,将视觉语言特征映射为机器人动作序列

模型通过 name_mapping 将预训练 checkpoint 中的权重映射到 FluxVLA 的模块结构:

name_mapping={
    'vlm_backbone.vlm': 'backbone.eagle_model',
    'vla_head': 'action_head'
}

场景二:私有真机数据训练#

本节以 Aloha 双臂机器人为例,演示如何使用自己采集的数据训练 GR00T-N1.5 模型并进行真机部署。UR3 单臂等其他机器人的流程类似,只需调整维度和相机参数。

1. 数据准备#

1.1 数据格式要求#

私有数据需转换为 LeRobot v2.1 格式。如果原始数据为 HDF5 格式,可使用数据转换脚本进行转换(参见 [真机数据准备])。

转换后的目录结构:

your_dataset_path/
├── data/
│   └── chunk-000/
│       ├── episode_000000.parquet
│       └── ...
├── videos/
│   └── chunk-000/
│       ├── observation.images.cam_high/
│       │   ├── episode_000000.mp4
│       │   └── ...
│       ├── observation.images.cam_left_wrist/
│       │   └── ...
│       └── observation.images.cam_right_wrist/
│           └── ...
└── meta/
    ├── episodes.jsonl
    ├── episodes_stats.jsonl
    ├── info.json
    └── tasks.jsonl

1.2 确认关键信息#

在开始配置前,请先明确以下数据特征:

信息

Aloha 双臂示例

UR3 单臂示例

数据集路径

RealRobot_AgileX_aloha_lerobot_v2/...

RealRobot_UR3_lerobot_v2/...

相机名称

cam_high, cam_left_wrist, cam_right_wrist

cam_high, cam_wrist

相机数量

3

2

动作维度

14(左臂 7 + 右臂 7)

7(6 关节 + 1 夹爪)

状态维度

14

7

2. 配置文件详解#

configs/gr00t/gr00t_eagle_3b_aloha_4090_full_train.py 为例进行说明。

2.1 模型配置#

model = dict(
    type='LlavaVLA',
    pretrained_name_or_path='/path/to/models/GR00T-N1.5-3B',
    vlm_backbone=dict(
        type='EagleBackbone',
        vlm_path='FluxVLA/models/third_party_models/eagle2_hg_model'),
    vla_head=dict(
        type='FlowMatchingHead',
        state_dim=14,             # Aloha 双臂状态维度
        hidden_size=1024,
        input_embedding_dim=1536,
        num_layers=1,
        num_heads=4,
        num_inference_timesteps=4,
        num_step=32,              # Flow Matching 采样步数
        traj_length=10,
        action_dim=14),           # Aloha 双臂动作维度
    freeze_vlm_backbone=False,
    name_mapping={
        'vlm_backbone.vlm': 'backbone.eagle_model',
        'vla_head': 'action_head'
    },
    freeze_projector=False)

与 LIBERO 配置的差异:

参数

LIBERO

Aloha

说明

state_dim

8

14

机器人状态维度

action_dim

7

14

机器人动作维度

num_step

未设置(默认)

32

Flow Matching 扩散步数,更复杂的动作空间需要更多步

traj_length

10

10

预测轨迹长度,两者一致

推理时可额外定义 inference_model,格式与 model 完全一致,一般只需增加 dtype=None 参数:

inference_model = dict(
    type='LlavaVLA',
    pretrained_name_or_path='/path/to/models/GR00T-N1.5-3B',
    vlm_backbone=dict(
        type='EagleBackbone',
        dtype=None,
        vlm_path='FluxVLA/models/third_party_models/eagle2_hg_model'),
    vla_head=dict(
        type='FlowMatchingHead',
        state_dim=14,
        hidden_size=1024,
        input_embedding_dim=1536,
        num_layers=1,
        num_heads=4,
        num_step=32,
        num_inference_timesteps=4,
        traj_length=10,
        action_dim=14),
    freeze_vlm_backbone=False,
    name_mapping={
        'vlm_backbone.vlm': 'backbone.eagle_model',
        'vla_head': 'action_head'
    },
    freeze_projector=False)

2.2 数据配置#

train_dataloader = dict(
    batch_size=128,
    per_device_batch_size=8,
    per_device_num_workers=4,
    dataset=dict(
        type='DistributedRepeatingDataset',
        name_mappings={'observation.state': ['proprio', 'action']},
        statistic_keys=[
            'observation.state', 'observation.eepose', 'timestamp'
        ],
        datasets=[
            dict(
                type='ParquetDataset',
                data_root_path=[
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot_v2/20250601_20250615_02_4090',
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot_v2/20250616_20250630_02_4090',
                    # ... 可添加更多数据批次
                ],
                transforms=[
                    dict(
                        type='ProcessParquetInputs',
                        embodiment_id=30,
                        parquet_keys=[
                            'observation.state', 'observation.eepose',
                            'timestamp', 'actions', 'info', 'stats',
                            'action_masks'
                        ],
                        video_keys=[
                            'observation.images.cam_high',
                            'observation.images.cam_left_wrist',
                            'observation.images.cam_right_wrist'
                        ],
                        name_mappings={'observation.state': ['states']}),
                    dict(type='ParquetPrompter'),
                    dict(
                        type='ProcessPromptsWithImage',
                        max_len=900,
                        num_images=3,
                        tokenizer=dict(
                            type='PretrainedTokenizer',
                            model_path='/path/to/models/eagle2_hg_model',
                        )),
                    dict(type='ResizeImages', height=224, width=224),
                    dict(
                        type='NormalizeImages',
                        means=[[123.515625, 116.04492188, 103.59375],
                               [123.515625, 116.04492188, 103.59375],
                               [123.515625, 116.04492188, 103.59375]],
                        stds=[[58.27148438, 57.02636719, 57.27539062],
                              [58.27148438, 57.02636719, 57.27539062],
                              [58.27148438, 57.02636719, 57.27539062]]),
                    dict(
                        type='NormalizeStatesAndActions',
                        action_dim=14,
                        state_key='proprio',
                        action_key='action',
                        use_quantiles=False)
                ],
                action_window_size=32)
        ]))

与 LIBERO 数据配置的关键差异:

配置项

LIBERO

Aloha 私有数据

适配要点

data_root_path

单个路径(字符串)

路径列表

私有数据通常按批次存放

name_mappings

state → proprio; action → action

state → [proprio, action]

LIBERO 状态和动作分离;私有数据中 state 同时作为 proprio 和 action 参考

embodiment_id

31

30(Aloha)

不同机器人平台使用不同 ID

video_keys

2 个相机

3 个相机

根据实际硬件配置

num_images

2

3

video_keys 数量一致

max_len

600

900

更多相机需要更长的 token 序列

action_window_size

10

32

真机场景通常需要更长的动作窗口

备注

NormalizeImagesmeansstds 列表的长度必须等于相机数量。使用 Eagle 骨干时,上述默认值适用于所有相机。

2.3 训练配置#

runner = dict(
    type='FSDPTrainRunner',
    max_epochs=3,                     # 真机数据量大时 epoch 可较少
    learning_rate=2e-5,
    weight_decay=0.0,
    max_grad_norm=1.0,
    sampler=None,
    tokenizer=dict(
        type='PretrainedTokenizer',
        model_path='/path/to/models/eagle2_hg_model',
    ),
    collator=dict(
        type='DictCollator',
        keys=[
            'states', 'observation.eepose', 'timestamp', 'images',
            'img_masks', 'lang_tokens', 'lang_masks', 'actions',
            'action_masks', 'embodiment_ids'
        ],
        meta_keys=['task_description', 'prompt', 'info', 'stats']),
    metric=dict(
        type='VLAMetric',
        active_trackers=('jsonl', 'wandb'),
        run_dir='work_dirs',
        wandb_project='FluxVLA',
        wandb_entity='limx',
        grad_accumulation_steps=1,
        window_size=1),
    lr_scheduler_type='constant',
    warmup_ratio=0.0,
    enable_gradient_checkpointing=True,
    enable_mixed_precision_training=True,
    mixed_precision_dtype='bf16',
    sharding_strategy='full-shard',
    change_key_name=False)

训练轮数建议:

数据规模

建议 max_epochs

说明

< 100 episodes

10 ~ 24

数据量少,需要更多轮次拟合

100 ~ 1000 episodes

3 ~ 6

中等规模,常见真机采集量

1000 episodes

1 ~ 3

大规模数据,少量轮次即可

2.4 推理配置#

inference = dict(
    type='AlohaInferenceRunner',
    seed=7,
    task_descriptions={
        '1': 'pick up the yellow chicken toy with left arm',
        '2': 'place it in the brown flat cardboard box with right arm',
        # ... 根据实际任务添加更多
    },
    dataset=dict(
        type='PrivateInferenceDataset',
        embodiment_id=30,
        img_keys=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        transforms=[
            dict(type='PrivatePrompter'),
            dict(
                type='ProcessPromptsWithImage',
                max_len=900,
                num_images=3,
                tokenizer=dict(type='PretrainedTokenizer')),
            dict(type='ResizeImages', height=224, width=224),
            dict(
                type='NormalizeImages',
                means=[[123.515625, 116.04492188, 103.59375],
                       [123.515625, 116.04492188, 103.59375],
                       [123.515625, 116.04492188, 103.59375]],
                stds=[[58.27148438, 57.02636719, 57.27539062],
                      [58.27148438, 57.02636719, 57.27539062],
                      [58.27148438, 57.02636719, 57.27539062]]),
            dict(
                type='NormalizeStatesAndActions',
                action_dim=14,
                state_key='proprio',
                action_key='action',
                use_quantiles=False)
        ]),
    denormalize_action=dict(
        type='DenormalizePrivateAction',
        use_quantiles=False),
    operator=dict(
        type='AlohaOperator',
        img_front_topic='/camera_f/color/image_raw',
        img_left_topic='/camera_l/color/image_raw',
        img_right_topic='/camera_r/color/image_raw',
        img_front_depth_topic='/camera_f/depth/image_raw',
        img_left_depth_topic='/camera_l/depth/image_raw',
        img_right_depth_topic='/camera_r/depth/image_raw',
        puppet_arm_left_cmd_topic='/master/joint_left',
        puppet_arm_right_cmd_topic='/master/joint_right',
        puppet_arm_left_topic='/puppet/joint_left',
        puppet_arm_right_topic='/puppet/joint_right',
        robot_base_topic='/odom_raw',
        robot_base_cmd_topic='/cmd_vel'))

推理配置要点:

  • task_descriptions — 键为任务 ID(字符串),值为自然语言描述。推理时通过键盘选择任务

  • PrivateInferenceDataset — 用于真机推理的数据集类型,会加载 dataset_statistics.json 进行反归一化

  • PrivatePrompter — 替代训练时的 ParquetPrompter,从推理输入构建提示

  • DenormalizePrivateAction — 将模型输出的归一化动作转换为真实机器人动作

  • operator — 定义 ROS topic 通信接口,需根据实际硬件修改

警告

推理时的数据预处理参数(归一化方式、图像尺寸、action_dim 等)必须与训练时保持一致,否则推理效果会严重下降。

3. 启动训练#

cd /path/to/fluxvla

# 单节点 8 GPU 训练
export MLP_WORKER_GPU=8
export MLP_WORKER_NUM=1
export MLP_ROLE_INDEX=0
export MLP_WORKER_0_HOST=localhost
export MLP_WORKER_0_PORT=29500

bash scripts/train.sh \
    configs/gr00t/gr00t_eagle_3b_aloha_4090_full_train.py \
    work_dirs/gr00t_eagle_3b_aloha_4090_full_train

多节点训练示例(2 节点 × 8 GPU):

节点 0(主节点):

export MLP_WORKER_GPU=8
export MLP_WORKER_NUM=2
export MLP_ROLE_INDEX=0
export MLP_WORKER_0_HOST={MASTER_NODE_IP}
export MLP_WORKER_0_PORT=29500

bash scripts/train.sh \
    configs/gr00t/gr00t_eagle_3b_aloha_4090_full_train.py \
    work_dirs/gr00t_eagle_3b_aloha_4090_full_train

节点 1:

export MLP_WORKER_GPU=8
export MLP_WORKER_NUM=2
export MLP_ROLE_INDEX=1
export MLP_WORKER_0_HOST={MASTER_NODE_IP}
export MLP_WORKER_0_PORT=29500

bash scripts/train.sh \
    configs/gr00t/gr00t_eagle_3b_aloha_4090_full_train.py \
    work_dirs/gr00t_eagle_3b_aloha_4090_full_train

4. 真机部署#

使用训练好的 checkpoint 进行推理:

python scripts/inference_real_robot.py \
    --config configs/gr00t/gr00t_eagle_3b_aloha_4090_full_train.py \
    --ckpt-path work_dirs/gr00t_eagle_3b_aloha_4090_full_train/checkpoint_epoch_3.pt

适配其他机器人:UR3 单臂示例#

configs/gr00t/gr00t_eagle_3b_ur3_full_train.py 为参考,UR3 的配置与 Aloha 相比有以下关键差异:

配置项

Aloha 双臂

UR3 单臂

说明

state_dim / action_dim

14

7

UR3 只有 1 条机械臂

embodiment_id

30

31

不同机器人平台的标识

video_keys

3 个相机

2 个相机(cam_high, cam_wrist

UR3 通常使用头部 + 腕部相机

num_images

3

2

与相机数量一致

max_len

900

600

更少相机需要更短的 token 序列

inference.type

AlohaInferenceRunner

URInferenceRunner

不同的推理 Runner

operator.type

AlohaOperator

UROperator

不同的硬件通信接口

适配新机器人时,只需创建新的配置文件,调整上述参数即可。如果现有的推理 Runner 和 Operator 不能满足需求,可参考 :doc:../tutorials/private_engine 教程进行扩展。

可用配置汇总#

配置文件

数据

骨干

机器人

gr00t_eagle_3b_libero_10_full_train.py

LIBERO-10

Eagle

仿真(LIBERO)

gr00t_paligemma_3b_libero_10_full_train.py

LIBERO-10

PaliGemma

仿真(LIBERO)

gr00t_eagle_3b_aloha_4090_full_train.py

Aloha 私有数据

Eagle

Aloha 双臂

gr00t_eagle_3b_aloha_fold_towel_3cam_4090_full_train.py

Aloha 叠毛巾(3 相机)

Eagle

Aloha 双臂

gr00t_eagle_3b_aloha_fold_towel_4cam_4090_full_train.py

Aloha 叠毛巾(4 相机)

Eagle

Aloha 双臂

gr00t_eagle_3b_ur3_full_train.py

UR3 私有数据

Eagle

UR3 单臂

配置 Weights & Biases#

训练过程中的 loss 曲线和指标可通过 Weights & Biases 进行跟踪和可视化。

# 安装
pip install wandb

# 登录
wandb login

# 设置项目信息
export WANDB_PROJECT=FluxVLA
export WANDB_ENTITY=your-team-name
export WANDB_MODE=online

# 如需禁用
export WANDB_MODE=disabled

常见问题#

Q:LIBERO 训练和私有数据训练可以混合吗?

可以。在 datasets 列表中同时配置 LIBERO 和私有数据的 ParquetDataset 即可。但需注意两者的 embodiment_idvideo_keysaction_dim 等参数可能不同,需要分别配置各自的 transforms

Q:如何从中断的训练恢复?

使用 --resume-from 参数指定 checkpoint 路径:

bash scripts/train.sh \
    configs/gr00t/gr00t_eagle_3b_aloha_4090_full_train.py \
    work_dirs/gr00t_eagle_3b_aloha_4090_full_train \
    --resume-from work_dirs/gr00t_eagle_3b_aloha_4090_full_train/checkpoint_epoch_2.pt

Q:显存不足怎么办?

可尝试以下方法:

  1. 减小 per_device_batch_size(如从 8 降至 4 或 2)

  2. 开启 enable_gradient_checkpointing=True(以时间换空间)

  3. 调整 FSDP 分片策略(sharding_strategy='full-shard'

Q:GR00T-N1.5 和 π₀ 有什么区别?

两者在 FluxVLA 中共享相同的 VLA 头(FlowMatchingHead),主要区别在于视觉语言骨干:GR00T-N1.5 使用 EagleBackbone(Eagle2 模型),π₀ 使用 PaliGemma(PaliGemma 模型)。Eagle 骨干通常在多图像场景下表现更好。

Q:如何覆盖配置文件中的参数?

通过 --cfg-options 在命令行动态覆盖:

bash scripts/train.sh \
    configs/gr00t/gr00t_eagle_3b_libero_10_full_train.py \
    work_dirs/gr00t_custom \
    --cfg-options runner.max_epochs=50 runner.learning_rate=1e-5

ib