数据集接口#

本页整理训练/推理数据接口的文档化约定,便于按字段与 transform 快速定位配置项。

用途#

  • 定义训练数据来源(如 ParquetDataset

  • 定义字段映射与统计键

  • 组织 transforms 流水线

  • 保持训练与推理的数据处理一致性

核心参数#

train_dataloader#

字段

说明

per_device_batch_size

每个设备 batch size

per_device_num_workers

每个设备数据加载进程数

dataset.type

数据集封装类型(如 DistributedRepeatingDataset

dataset.name_mappings

字段名映射

dataset.statistic_keys

统计字段

datasets[].data_root_path

数据路径列表

datasets[].transforms

数据处理流水线

高关联 transform 字段#

transform

常见字段

ProcessParquetInputs

parquet_keysvideo_keysembodiment_id

ProcessPromptsWithImage

max_lennum_imagestokenizer

ResizeImages

heightwidth

NormalizeImages

meansstds

NormalizeStatesAndActions

state_dimaction_dimnorm_type

最小示例#

train_dataloader = dict(
    per_device_batch_size=8,
    per_device_num_workers=4,
    dataset=dict(
        type='DistributedRepeatingDataset',
        datasets=[
            dict(
                type='ParquetDataset',
                data_root_path=['./datasets/your_dataset'],
                transforms=[
                    dict(type='ProcessParquetInputs'),
                    dict(type='ProcessPromptsWithImage', num_images=3),
                    dict(type='ResizeImages', height=224, width=224),
                    dict(type='NormalizeImages'),
                    dict(type='NormalizeStatesAndActions', state_dim=64, action_dim=32)
                ])
        ]))

关联教程#