π₀.₅ 模型训练#
本教程提供了一份完整指南,介绍如何使用 FluxVLA 框架对 π₀.₅ 模型进行微调训练。π₀.₅ 是 π₀ 的改进版本(PI05FlowMatching 继承自 PI0FlowMatching),两者共享相同的架构设计,π₀.₅ 在动作生成上做了进一步优化。内容涵盖两种典型场景:
LIBERO 仿真数据训练 — 使用开箱即用的 LIBERO 基准数据集进行快速验证(详见 LIBERO 仿真数据训练,支持 Full Finetune 和 LoRA)
私有真机数据训练 — 使用自己采集的机器人数据进行定制化微调(支持单机器人和多机器人混合训练)
模型架构#
π₀.₅ 在 FluxVLA 中由以下核心模块组成:
模块 |
类型 |
作用 |
|---|---|---|
VLA 容器 |
|
模型外壳,协调视觉语言理解和动作生成(继承自 |
视觉语言骨干 |
|
PaliGemma 3B 多模态模型,处理图像和语言输入 |
动作专家 |
|
Gemma 2B 语言模型,作为动作生成的专家网络 |
状态/动作投影器 |
|
将机器人状态和动作映射到模型内部维度 |
与 GR00T-N1.5 的对比:
特性 |
π₀.₅ |
GR00T-N1.5 |
|---|---|---|
模型类型 |
|
|
视觉骨干 |
|
|
动作生成 |
独立的 Gemma Expert + LinearProjector |
|
文本处理 |
|
|
预训练权重 |
pi0/pi0.5 checkpoint( |
GR00T-N1.5-3B |
场景二:私有真机数据训练#
本节演示如何使用私有数据训练 π₀ 模型。π₀ 提供了 Aloha 双臂、UR3 单臂、以及 Aloha + UR3 混合训练的配置文件。
1. 数据准备#
私有数据需转换为 LeRobot v2.1 格式(参见 真机数据准备)。
2. Aloha 双臂训练#
使用配置文件 configs/pi0/pi0_paligemma_aloha_full_train.py。
2.1 模型配置差异#
与 LIBERO 配置相比,Aloha 需要调整维度参数:
model = dict(
type='PI0FlowMatching',
# ... PaliGemma 和 Gemma Expert 配置不变 ...
proj_width=1024,
n_action_steps=32, # Aloha 使用更长的动作序列
state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
max_action_dim=14, # 双臂最大动作维度
# ...
freeze_vlm_backbone=True, # 冻结骨干,只训练动作相关层
pretrained_name_or_path=
'/path/to/checkpoints/pi0/model.safetensors',
)
维度参数对比:
参数 |
LIBERO |
Aloha 双臂 |
说明 |
|---|---|---|---|
|
8 |
14 |
双臂各 7 维状态 |
|
7 |
14 |
双臂各 7 维动作 |
|
7 |
14 |
输出维度与输入一致 |
|
10 |
32 |
真机场景使用更长的动作序列 |
|
不设置 |
14 |
限制最大动作维度 |
备注
Aloha 配置使用 freeze_vlm_backbone=True 冻结视觉语言骨干,仅训练动作相关的投影层和专家网络。这是因为从 pi0 官方 checkpoint 出发,骨干已经过充分预训练。
2.2 数据配置#
train_dataloader = dict(
batch_size=128,
per_device_batch_size=8,
per_device_num_workers=4,
dataset=dict(
type='DistributedRepeatingDataset',
name_mappings={'observation.state': ['proprio', 'action']},
statistic_keys=[
'observation.state', 'observation.eepose', 'timestamp'
],
datasets=[
dict(
type='ParquetDataset',
data_root_path=[
'/path/to/data/RealRobot_AgileX_aloha_lerobot_v2/20250601_20250615_02_4090',
'/path/to/data/RealRobot_AgileX_aloha_lerobot_v2/20250616_20250630_02_4090',
# ... 可添加更多数据批次
],
transforms=[
dict(
type='ProcessParquetInputs',
parquet_keys=[
'observation.state', 'observation.eepose',
'timestamp', 'actions', 'info', 'stats',
'action_masks'
],
video_keys=[
'observation.images.cam_high',
'observation.images.cam_left_wrist',
'observation.images.cam_right_wrist'
],
name_mappings={'observation.state': ['states']}),
dict(type='ParquetPrompter'),
dict(
type='ProcessPrompts',
tokenizer=dict(
type='PretrainedTokenizer',
model_path='/path/to/checkpoints/pi0',
)),
dict(type='ResizeImages', height=224, width=224),
dict(
type='NormalizeImages',
means=[[123.515625, 116.04492188, 103.59375],
[123.515625, 116.04492188, 103.59375],
[123.515625, 116.04492188, 103.59375]],
stds=[[58.27148438, 57.02636719, 57.27539062],
[58.27148438, 57.02636719, 57.27539062],
[58.27148438, 57.02636719, 57.27539062]]),
dict(
type='NormalizeStatesAndActions',
action_dim=14,
state_key='proprio',
action_key='action',
use_quantiles=False)
],
action_window_size=32)
]))
2.3 推理配置#
inference = dict(
type='AlohaInferenceRunner',
seed=7,
task_descriptions={
'1': 'pick up the yellow chicken toy with left arm',
'2': 'place it in the brown flat cardboard box with right arm',
# ... 根据实际任务添加
},
dataset=dict(
type='PrivateInferenceDataset',
img_keys=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
transforms=[
dict(type='PrivatePrompter'),
dict(
type='ProcessPrompts',
tokenizer=dict(type='PretrainedTokenizer')),
dict(type='ResizeImages', height=224, width=224),
dict(
type='NormalizeImages',
means=[[123.515625, 116.04492188, 103.59375],
[123.515625, 116.04492188, 103.59375],
[123.515625, 116.04492188, 103.59375]],
stds=[[58.27148438, 57.02636719, 57.27539062],
[58.27148438, 57.02636719, 57.27539062],
[58.27148438, 57.02636719, 57.27539062]]),
dict(
type='NormalizeStatesAndActions',
action_dim=16,
state_key='proprio',
action_key='action',
use_quantiles=False)
]),
denormalize_action=dict(
type='DenormalizePrivateAction',
use_quantiles=False),
operator=dict(
type='AlohaOperator',
img_front_topic='/camera_f/color/image_raw',
img_left_topic='/camera_l/color/image_raw',
img_right_topic='/camera_r/color/image_raw',
img_front_depth_topic='/camera_f/depth/image_raw',
img_left_depth_topic='/camera_l/depth/image_raw',
img_right_depth_topic='/camera_r/depth/image_raw',
puppet_arm_left_cmd_topic='/master/joint_left',
puppet_arm_right_cmd_topic='/master/joint_right',
puppet_arm_left_topic='/puppet/joint_left',
puppet_arm_right_topic='/puppet/joint_right',
robot_base_topic='/odom_raw',
robot_base_cmd_topic='/cmd_vel'))
3. UR3 单臂训练#
使用配置文件 configs/pi0/pi0_paligemma_ur3_full_train.py。
与 Aloha 配置的关键差异:
配置项 |
Aloha 双臂 |
UR3 单臂 |
说明 |
|---|---|---|---|
|
14 |
7 |
UR3 单臂状态维度 |
|
14 |
7 |
UR3 单臂动作维度 |
|
14 |
7 |
输出动作维度 |
|
14 |
7 |
最大动作维度 |
|
3 个相机 |
2 个相机 |
|
|
|
|
不同推理 Runner |
|
|
|
不同硬件通信接口 |
4. 多机器人混合训练#
π₀ 支持将多种机器人的数据混合训练,使模型具备跨机器人的泛化能力。使用配置文件 configs/pi0/pi0_paligemma_aloha+ur_full_train.py。
混合训练的关键配置:
model = dict(
type='PI0FlowMatching',
# ...
n_action_steps=32,
state_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
action_in_proj=dict(type='LinearProjector', in_dim=14, out_dim=1024),
action_out_proj=dict(type='LinearProjector', in_dim=1024, out_dim=14),
max_action_dim=14, # 取所有机器人中的最大维度
# ...
)
数据集使用分组字典格式,按机器人类型组织:
train_dataloader = dict(
dataset=dict(
type='DistributedRepeatingDataset',
name_mappings={'observation.state': ['proprio', 'action']},
dim=14, # 统一的最大维度
datasets=dict(
aloha_4090=[ # Aloha 4090 数据组
dict(type='ParquetDataset', data_root_path='...', ...),
],
aloha_4060=[ # Aloha 4060 数据组
dict(type='ParquetDataset', data_root_path='...', ...),
dict(type='ParquetDataset', data_root_path='...', ...),
],
ur3=[ # UR3 数据组
dict(type='ParquetDataset', data_root_path='...', ...),
]
)))
警告
混合训练时,不同维度的机器人数据需要通过 PadKeyToDim 变换对齐到统一维度。例如 UR3 的 7 维数据需要 padding 到 14 维:
# 在 UR3 的 transforms 中添加
dict(
type='PadKeyToDim',
keys=['actions', 'states', 'observation.eepose'],
dim=14),
5. 启动训练和部署#
cd /path/to/fluxvla
export MLP_WORKER_GPU=8
export MLP_WORKER_NUM=1
export MLP_ROLE_INDEX=0
export MLP_WORKER_0_HOST=localhost
export MLP_WORKER_0_PORT=29500
# Aloha 训练
bash scripts/train.sh \
configs/pi0/pi0_paligemma_aloha_full_train.py \
work_dirs/pi0_paligemma_aloha_full_train
# UR3 训练
bash scripts/train.sh \
configs/pi0/pi0_paligemma_ur3_full_train.py \
work_dirs/pi0_paligemma_ur3_full_train
# Aloha + UR3 混合训练
bash scripts/train.sh \
configs/pi0/pi0_paligemma_aloha+ur_full_train.py \
work_dirs/pi0_paligemma_aloha+ur_full_train
真机部署:
python scripts/inference_real_robot.py \
--config configs/pi0/pi0_paligemma_aloha_full_train.py \
--ckpt-path work_dirs/pi0_paligemma_aloha_full_train/checkpoint_epoch_3.pt
可用配置汇总#
配置文件 |
数据 |
训练策略 |
机器人 |
|---|---|---|---|
|
LIBERO-10 |
Full(从零训练) |
仿真 |
|
LIBERO-10 |
Full Finetune |
仿真 |
|
LIBERO-10 |
Full Finetune |
仿真 |
|
LIBERO-10 |
LoRA |
仿真 |
|
LIBERO-90 |
Full Finetune |
仿真 |
|
LIBERO-Spatial |
Full Finetune |
仿真 |
|
LIBERO-Object |
Full Finetune |
仿真 |
|
LIBERO-Goal |
Full Finetune |
仿真 |
|
Aloha 私有数据 |
Full |
Aloha 双臂 |
|
UR3 私有数据 |
Full |
UR3 单臂 |
|
Aloha + UR3 |
Full(混合) |
多机器人 |
常见问题#
Q:π₀ 的 Full Finetune 和 LoRA 如何选择?
Full Finetune 效果通常更好,但需要更多显存和训练时间。适合最终部署使用。
LoRA 显存需求低(DDP 即可),训练速度快。适合快速实验和超参搜索。LoRA 使用较高学习率(如 5e-4 vs 2e-5)。
Q:私有数据训练时为什么要冻结骨干(freeze_vlm_backbone=True)?
π₀ 的 PaliGemma 骨干已经过大规模预训练,视觉和语言理解能力很强。在私有数据量有限时,冻结骨干可以避免过拟合,同时大幅减少可训练参数量。如果数据充足,也可以设为 False 进行全参数微调。
Q:混合训练时不同维度的数据如何处理?
使用 PadKeyToDim 变换将低维数据 padding 到最大维度。模型中设置 max_action_dim 为所有机器人中的最大动作维度。框架会自动处理 padding 后的维度对齐。
Q:π₀ 和 GR00T 的预训练权重来源不同?
π₀ 权重来自 openpi_pytorch 项目的转换版本,格式为
.safetensorsGR00T-N1.5 权重来自 NVIDIA 发布的 GR00T-N1.5-3B checkpoint
两者在 FluxVLA 中通过各自的 name_mapping 进行权重映射。
Q:如何从中断的训练恢复?
bash scripts/train.sh \
configs/pi0/pi0_paligemma_aloha_full_train.py \
work_dirs/pi0_paligemma_aloha_full_train \
--resume-from work_dirs/pi0_paligemma_aloha_full_train/checkpoint_epoch_2.pt
Q:如何覆盖配置参数?
bash scripts/train.sh \
configs/pi0/pi0_paligemma_libero_10_full_finetune_pytorch.py \
work_dirs/pi0_custom \
--cfg-options runner.max_epochs=50 runner.learning_rate=1e-5