数据配置#

数据配置定义了训练数据的来源、格式及预处理方式。我们通过 train_dataloader 字典进行配置,核心内容包括:

  • per_device_batch_size:每个 GPU 设备的批量大小,影响显存占用和训练速度

  • per_device_num_workers:每个设备的数据加载进程数

  • dataset:数据集配置,包含以下子项:

    • type:数据集类型,如 DistributedRepeatingDataset 支持分布式训练的重复采样

    • name_mappings:字段名称映射,用于统一不同数据源的键名

    • statistic_keys:需要统计的键(用于计算归一化参数)

    • datasets:数据集列表,支持多数据集混合训练

  • transforms:数据预处理流水线,按顺序执行变换

以下是一个完整的数据配置示例:

train_dataloader = dict(
    per_device_batch_size=8,
    per_device_num_workers=4,
    dataset=dict(
        type='DistributedRepeatingDataset',
        name_mappings={
            'observation.state': ['proprio', 'action']
        },
        statistic_keys=[
            'observation.state',
            'observation.eepose',
            'timestamp'
        ],
        datasets=[
            dict(
                type='ParquetDataset',
                data_root_path=[
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot/20251030_20251030_01_4090',
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot/20251031_20251031_01_4090',
                ],
                transforms=[
                    dict(
                        type='ProcessParquetInputs',
                        parquet_keys=[
                            'observation.state',
                            'timestamp',
                            'actions',
                            'info',
                            'stats',
                            'action_masks'
                        ],
                        video_keys=[
                            'observation.images.cam_high',
                            'observation.images.cam_left_wrist',
                            'observation.images.cam_right_wrist',
                        ],
                        name_mappings={
                            'observation.state': ['states'],
                            'actions': ['actions']
                        }),
                    dict(
                        type='NormalizeStatesAndActions',
                        action_dim=32,
                        state_dim=32,
                        state_key='proprio',
                        action_key='action',
                        norm_type='min_max'),
                    dict(type='PreparePromptWithState'),
                    dict(
                        type='ProcessPrompts',
                        max_len=200,
                        tokenizer=dict(
                            type='PretrainedTokenizer',
                            model_path='/path/to/checkpoints/paligemma-3b-pt-224',
                        )),
                    dict(type='ResizeImages', height=224, width=224),
                    dict(type='SimpleNormalizeImages'),
                ],
                action_window_size=50)
        ]))