# Data Configuration
The data configuration defines the source, format, and preprocessing of the training data. It is specified through the `train_dataloader` dictionary, whose core fields are:

- `per_device_batch_size`: batch size per GPU device; affects memory usage and training speed
- `per_device_num_workers`: number of data-loading worker processes per device
- `dataset`: dataset configuration, containing the following sub-fields:
  - `type`: dataset type, e.g. `DistributedRepeatingDataset`, which supports repeated sampling for distributed training
  - `name_mappings`: field-name mapping, used to unify key names across different data sources
  - `statistic_keys`: keys whose statistics are collected (used to compute normalization parameters)
  - `datasets`: list of datasets; supports mixed training over multiple datasets
  - `transforms`: data preprocessing pipeline; transforms are applied in order
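To make the role of `name_mappings` concrete, the sketch below shows one plausible way a field-name mapping could be applied to a sample: the value under each source key is copied to the listed target keys, so downstream transforms see a uniform schema. The `remap_keys` helper is our own illustration, not a function from the framework.

```python
def remap_keys(sample, name_mappings):
    """Copy each source key's value to its mapped target keys.

    Hypothetical helper illustrating how a name_mappings entry like
    {'observation.state': ['proprio', 'action']} could unify key names.
    """
    out = dict(sample)
    for src, targets in name_mappings.items():
        if src in out:
            for tgt in targets:
                out[tgt] = out[src]
    return out

sample = {'observation.state': [0.1, 0.2]}
mapped = remap_keys(sample, {'observation.state': ['proprio', 'action']})
```

After remapping, both `mapped['proprio']` and `mapped['action']` hold the original `observation.state` value, which matches how the example configuration below feeds the same state field to transforms that expect `proprio` and `action` keys.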
Below is a complete data configuration example:
```python
train_dataloader = dict(
    per_device_batch_size=8,
    per_device_num_workers=4,
    dataset=dict(
        type='DistributedRepeatingDataset',
        name_mappings={
            'observation.state': ['proprio', 'action']
        },
        statistic_keys=[
            'observation.state',
            'observation.eepose',
            'timestamp'
        ],
        datasets=[
            dict(
                type='ParquetDataset',
                data_root_path=[
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot/20251030_20251030_01_4090',
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot/20251031_20251031_01_4090',
                ],
                transforms=[
                    dict(
                        type='ProcessParquetInputs',
                        parquet_keys=[
                            'observation.state',
                            'timestamp',
                            'actions',
                            'info',
                            'stats',
                            'action_masks'
                        ],
                        video_keys=[
                            'observation.images.cam_high',
                            'observation.images.cam_left_wrist',
                            'observation.images.cam_right_wrist',
                        ],
                        name_mappings={
                            'observation.state': ['states'],
                            'actions': ['actions']
                        }),
                    dict(
                        type='NormalizeStatesAndActions',
                        action_dim=32,
                        state_dim=32,
                        state_key='proprio',
                        action_key='action',
                        norm_type='min_max'),
                    dict(type='PreparePromptWithState'),
                    dict(
                        type='ProcessPrompts',
                        max_len=200,
                        tokenizer=dict(
                            type='PretrainedTokenizer',
                            model_path='/path/to/checkpoints/paligemma-3b-pt-224',
                        )),
                    dict(type='ResizeImages', height=224, width=224),
                    dict(type='SimpleNormalizeImages'),
                ],
                action_window_size=50)
        ]))
```
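The `NormalizeStatesAndActions` transform above uses `norm_type='min_max'`, which typically rescales each state and action dimension into a fixed range using the min/max statistics gathered via `statistic_keys`. The sketch below illustrates the idea; `min_max_normalize` and its `eps` guard are our own assumptions, not the framework's actual implementation.

```python
def min_max_normalize(values, stat_min, stat_max, eps=1e-8):
    """Rescale each dimension from [min, max] to [-1, 1].

    Hypothetical sketch of min_max normalization; eps guards against
    zero-range dimensions where min == max.
    """
    return [2.0 * (v - lo) / (hi - lo + eps) - 1.0
            for v, lo, hi in zip(values, stat_min, stat_max)]

# A 3-dim state whose per-dimension statistics span [0, 10]:
state = [0.0, 5.0, 10.0]
normed = min_max_normalize(state, stat_min=[0.0] * 3, stat_max=[10.0] * 3)
```

Here the minimum maps to -1, the midpoint to 0, and the maximum to 1, so states and actions from heterogeneous robots end up on a comparable scale before training.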