π₀.₅ 模型训练#

本教程提供了一份完整指南,介绍如何使用 FluxVLA 框架对 π₀.₅ 模型进行微调训练。π₀.₅ 是 π₀ 的改进版本(PI05FlowMatching 继承自 PI0FlowMatching),两者共享相同的架构设计,π₀.₅ 在动作生成上做了进一步优化。内容涵盖两种典型场景:

  1. LIBERO 仿真数据训练 — 使用开箱即用的 LIBERO 基准数据集进行快速验证(详见 LIBERO 仿真数据训练,支持 Full Finetune 和 LoRA)

  2. 私有真机数据训练 — 使用自己采集的机器人数据进行定制化微调(支持单机器人和多机器人混合训练)

模型架构#

π₀.₅ 在 FluxVLA 中由以下核心模块组成:

模块

类型

作用

VLA 容器

PI05FlowMatching

模型外壳,协调视觉语言理解和动作生成(继承自 PI0FlowMatching

视觉语言骨干

PaliGemma

PaliGemma 3B 多模态模型,处理图像和语言输入

动作专家

GemmaLLMBackbone

Gemma 2B 语言模型,作为动作生成的专家网络

状态/动作投影器

LinearProjector

将机器人状态和动作映射到模型内部维度

与 GR00T-N1.5 的对比:

特性

π₀.₅

GR00T-N1.5

模型类型

PI05FlowMatching

LlavaVLA

视觉骨干

PaliGemma(SigLIP + Gemma)

EagleBackbone(Eagle2)

动作生成

独立的 Gemma Expert + LinearProjector

FlowMatchingHead

文本处理

ProcessPrompts

ProcessPromptsWithImage

预训练权重

pi0/pi0.5 checkpoint(.safetensors

GR00T-N1.5-3B

场景二:私有真机数据训练#

本节演示如何使用私有数据训练 π₀ 模型。π₀ 提供了 Aloha 双臂、UR3 单臂、以及 Aloha + UR3 混合训练的配置文件。

1. 数据准备#

私有数据需转换为 LeRobot v2.1 格式(参见 真机数据准备)。

2. Aloha 双臂训练#

使用配置文件 configs/pi0/pi0_paligemma_aloha_full_train.py

2.1 模型配置差异#

与 LIBERO 配置相比,Aloha 需要调整维度参数:

model = dict(
    type='PI0FlowMatching',
    # ... PaliGemma 和 Gemma Expert 配置不变 ...
    proj_width=1024,
    n_action_steps=32,                # Aloha 使用更长的动作序列
    state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
    max_action_dim=14,                # 双臂最大动作维度
    # ...
    freeze_vlm_backbone=True,         # 冻结骨干,只训练动作相关层
    pretrained_name_or_path=
        '/path/to/checkpoints/pi0/model.safetensors',
)

维度参数对比:

参数

LIBERO

Aloha 双臂

说明

state_proj.in_dim

8

14

双臂各 7 维状态

action_in_proj.in_dim

7

14

双臂各 7 维动作

action_out_proj.out_dim

7

14

输出维度与输入一致

n_action_steps

10

32

真机场景使用更长的动作序列

max_action_dim

不设置

14

限制最大动作维度

备注

Aloha 配置使用 freeze_vlm_backbone=True 冻结视觉语言骨干,仅训练动作相关的投影层和专家网络。这是因为从 pi0 官方 checkpoint 出发,骨干已经过充分预训练。

2.2 数据配置#

train_dataloader = dict(
    batch_size=128,
    per_device_batch_size=8,
    per_device_num_workers=4,
    dataset=dict(
        type='DistributedRepeatingDataset',
        name_mappings={'observation.state': ['proprio', 'action']},
        statistic_keys=[
            'observation.state', 'observation.eepose', 'timestamp'
        ],
        datasets=[
            dict(
                type='ParquetDataset',
                data_root_path=[
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot_v2/20250601_20250615_02_4090',
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot_v2/20250616_20250630_02_4090',
                    # ... 可添加更多数据批次
                ],
                transforms=[
                    dict(
                        type='ProcessParquetInputs',
                        parquet_keys=[
                            'observation.state', 'observation.eepose',
                            'timestamp', 'actions', 'info', 'stats',
                            'action_masks'
                        ],
                        video_keys=[
                            'observation.images.cam_high',
                            'observation.images.cam_left_wrist',
                            'observation.images.cam_right_wrist'
                        ],
                        name_mappings={'observation.state': ['states']}),
                    dict(type='ParquetPrompter'),
                    dict(
                        type='ProcessPrompts',
                        tokenizer=dict(
                            type='PretrainedTokenizer',
                            model_path='/path/to/checkpoints/pi0',
                        )),
                    dict(type='ResizeImages', height=224, width=224),
                    dict(
                        type='NormalizeImages',
                        means=[[123.515625, 116.04492188, 103.59375],
                               [123.515625, 116.04492188, 103.59375],
                               [123.515625, 116.04492188, 103.59375]],
                        stds=[[58.27148438, 57.02636719, 57.27539062],
                              [58.27148438, 57.02636719, 57.27539062],
                              [58.27148438, 57.02636719, 57.27539062]]),
                    dict(
                        type='NormalizeStatesAndActions',
                        action_dim=14,
                        state_key='proprio',
                        action_key='action',
                        use_quantiles=False)
                ],
                action_window_size=32)
        ]))

2.3 推理配置#

inference = dict(
    type='AlohaInferenceRunner',
    seed=7,
    task_descriptions={
        '1': 'pick up the yellow chicken toy with left arm',
        '2': 'place it in the brown flat cardboard box with right arm',
        # ... 根据实际任务添加
    },
    dataset=dict(
        type='PrivateInferenceDataset',
        img_keys=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        transforms=[
            dict(type='PrivatePrompter'),
            dict(
                type='ProcessPrompts',
                tokenizer=dict(type='PretrainedTokenizer')),
            dict(type='ResizeImages', height=224, width=224),
            dict(
                type='NormalizeImages',
                means=[[123.515625, 116.04492188, 103.59375],
                       [123.515625, 116.04492188, 103.59375],
                       [123.515625, 116.04492188, 103.59375]],
                stds=[[58.27148438, 57.02636719, 57.27539062],
                      [58.27148438, 57.02636719, 57.27539062],
                      [58.27148438, 57.02636719, 57.27539062]]),
            dict(
                type='NormalizeStatesAndActions',
                action_dim=16,
                state_key='proprio',
                action_key='action',
                use_quantiles=False)
        ]),
    denormalize_action=dict(
        type='DenormalizePrivateAction',
        use_quantiles=False),
    operator=dict(
        type='AlohaOperator',
        img_front_topic='/camera_f/color/image_raw',
        img_left_topic='/camera_l/color/image_raw',
        img_right_topic='/camera_r/color/image_raw',
        img_front_depth_topic='/camera_f/depth/image_raw',
        img_left_depth_topic='/camera_l/depth/image_raw',
        img_right_depth_topic='/camera_r/depth/image_raw',
        puppet_arm_left_cmd_topic='/master/joint_left',
        puppet_arm_right_cmd_topic='/master/joint_right',
        puppet_arm_left_topic='/puppet/joint_left',
        puppet_arm_right_topic='/puppet/joint_right',
        robot_base_topic='/odom_raw',
        robot_base_cmd_topic='/cmd_vel'))

3. UR3 单臂训练#

使用配置文件 configs/pi0/pi0_paligemma_ur3_full_train.py

与 Aloha 配置的关键差异:

配置项

Aloha 双臂

UR3 单臂

说明

state_proj.in_dim

14

7

UR3 单臂状态维度

action_in_proj.in_dim

14

7

UR3 单臂动作维度

action_out_proj.out_dim

14

7

输出动作维度

max_action_dim

14

7

最大动作维度

video_keys

3 个相机

2 个相机

cam_high + cam_wrist

inference.type

AlohaInferenceRunner

URInferenceRunner

不同推理 Runner

operator.type

AlohaOperator

UROperator

不同硬件通信接口

4. 多机器人混合训练#

π₀ 支持将多种机器人的数据混合训练,使模型具备跨机器人的泛化能力。使用配置文件 configs/pi0/pi0_paligemma_aloha+ur_full_train.py

混合训练的关键配置:

model = dict(
    type='PI0FlowMatching',
    # ...
    n_action_steps=32,
    state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
    action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
    max_action_dim=14,               # 取所有机器人中的最大维度
    # ...
)

数据集使用分组字典格式,按机器人类型组织:

train_dataloader = dict(
    dataset=dict(
        type='DistributedRepeatingDataset',
        name_mappings={'observation.state': ['proprio', 'action']},
        dim=14,                      # 统一的最大维度
        datasets=dict(
            aloha_4090=[             # Aloha 4090 数据组
                dict(type='ParquetDataset', data_root_path='...', ...),
            ],
            aloha_4060=[             # Aloha 4060 数据组
                dict(type='ParquetDataset', data_root_path='...', ...),
                dict(type='ParquetDataset', data_root_path='...', ...),
            ],
            ur3=[                    # UR3 数据组
                dict(type='ParquetDataset', data_root_path='...', ...),
            ]
        )))

警告

混合训练时,不同维度的机器人数据需要通过 PadKeyToDim 变换对齐到统一维度。例如 UR3 的 7 维数据需要 padding 到 14 维:

# 在 UR3 的 transforms 中添加
dict(
    type='PadKeyToDim',
    keys=['actions', 'states', 'observation.eepose'],
    dim=14),

5. 启动训练和部署#

cd /path/to/fluxvla

export MLP_WORKER_GPU=8
export MLP_WORKER_NUM=1
export MLP_ROLE_INDEX=0
export MLP_WORKER_0_HOST=localhost
export MLP_WORKER_0_PORT=29500

# Aloha 训练
bash scripts/train.sh \
    configs/pi0/pi0_paligemma_aloha_full_train.py \
    work_dirs/pi0_paligemma_aloha_full_train

# UR3 训练
bash scripts/train.sh \
    configs/pi0/pi0_paligemma_ur3_full_train.py \
    work_dirs/pi0_paligemma_ur3_full_train

# Aloha + UR3 混合训练
bash scripts/train.sh \
    configs/pi0/pi0_paligemma_aloha+ur_full_train.py \
    work_dirs/pi0_paligemma_aloha+ur_full_train

真机部署:

python scripts/inference_real_robot.py \
    --config configs/pi0/pi0_paligemma_aloha_full_train.py \
    --ckpt-path work_dirs/pi0_paligemma_aloha_full_train/checkpoint_epoch_3.pt

可用配置汇总#

配置文件

数据

训练策略

机器人

pi0_paligemma_libero_10_full_train.py

LIBERO-10

Full(从零训练)

仿真

pi0_paligemma_libero_10_full_finetune_pytorch.py

LIBERO-10

Full Finetune

仿真

pi0_paligemma_libero_10_full_finetune.py

LIBERO-10

Full Finetune

仿真

pi0_paligemma_libero_10_lora_finetune.py

LIBERO-10

LoRA

仿真

pi0_paligemma_libero_90_full_finetune_pytorch.py

LIBERO-90

Full Finetune

仿真

pi0_paligemma_libero_spatial_full_finetune_pytorch.py

LIBERO-Spatial

Full Finetune

仿真

pi0_paligemma_libero_object_full_finetune_pytorch.py

LIBERO-Object

Full Finetune

仿真

pi0_paligemma_libero_goal_full_finetune_pytorch.py

LIBERO-Goal

Full Finetune

仿真

pi0_paligemma_aloha_full_train.py

Aloha 私有数据

Full

Aloha 双臂

pi0_paligemma_ur3_full_train.py

UR3 私有数据

Full

UR3 单臂

pi0_paligemma_aloha+ur_full_train.py

Aloha + UR3

Full(混合)

多机器人

常见问题#

Q:π₀ 的 Full Finetune 和 LoRA 如何选择?

  • Full Finetune 效果通常更好,但需要更多显存和训练时间。适合最终部署使用。

  • LoRA 显存需求低(DDP 即可),训练速度快。适合快速实验和超参搜索。LoRA 使用较高学习率(如 5e-4 vs 2e-5)。

Q:私有数据训练时为什么要冻结骨干(freeze_vlm_backbone=True)?

π₀ 的 PaliGemma 骨干已经过大规模预训练,视觉和语言理解能力很强。在私有数据量有限时,冻结骨干可以避免过拟合,同时大幅减少可训练参数量。如果数据充足,也可以设为 False 进行全参数微调。

Q:混合训练时不同维度的数据如何处理?

使用 PadKeyToDim 变换将低维数据 padding 到最大维度。模型中设置 max_action_dim 为所有机器人中的最大动作维度。框架会自动处理 padding 后的维度对齐。

Q:π₀ 和 GR00T 的预训练权重来源不同?

  • π₀ 权重来自 openpi_pytorch 项目的转换版本,格式为 .safetensors

  • GR00T-N1.5 权重来自 NVIDIA 发布的 GR00T-N1.5-3B checkpoint

两者在 FluxVLA 中通过各自的 name_mapping 进行权重映射。

Q:如何从中断的训练恢复?

bash scripts/train.sh \
    configs/pi0/pi0_paligemma_aloha_full_train.py \
    work_dirs/pi0_paligemma_aloha_full_train \
    --resume-from work_dirs/pi0_paligemma_aloha_full_train/checkpoint_epoch_2.pt

Q:如何覆盖配置参数?

bash scripts/train.sh \
    configs/pi0/pi0_paligemma_libero_10_full_finetune_pytorch.py \
    work_dirs/pi0_custom \
    --cfg-options runner.max_epochs=50 runner.learning_rate=1e-5