Data Configuration

Data configuration defines the source, format, and preprocessing of the training data. It is set through the train_dataloader dict, whose core fields are:

  • per_device_batch_size: Batch size per GPU device, affecting memory usage and training speed

  • per_device_num_workers: Number of data loading workers per device

  • dataset: Dataset configuration with the following sub-fields:

    • type: Dataset type, e.g., DistributedRepeatingDataset for distributed training with repeated sampling

    • name_mappings: Field name mappings to unify key names across different data sources

    • statistic_keys: Keys for which statistics are computed (used for normalization parameters)

    • datasets: List of datasets, supporting multi-dataset mixed training

  • transforms: Data preprocessing pipeline, applied in sequence; in the example below it is set on each entry of datasets
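
The "applied in sequence" behaviour of transforms can be pictured with a minimal sketch. It assumes each transform is a callable that takes a sample dict and returns a new one. The names build_pipeline, rename_key, and scale_values are hypothetical and only illustrate the idea; they are not the framework's actual classes or registry.

```python
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]
Transform = Callable[[Sample], Sample]


def rename_key(old: str, new: str) -> Transform:
    """Return a transform that moves sample[old] to sample[new] (hypothetical helper)."""
    def apply(sample: Sample) -> Sample:
        out = dict(sample)  # shallow copy so the input sample is not mutated
        out[new] = out.pop(old)
        return out
    return apply


def scale_values(key: str, factor: float) -> Transform:
    """Return a transform that multiplies every value under `key` by `factor`."""
    def apply(sample: Sample) -> Sample:
        out = dict(sample)
        out[key] = [v * factor for v in out[key]]
        return out
    return apply


def build_pipeline(transforms: List[Transform]) -> Transform:
    """Chain transforms so that each one receives the previous one's output."""
    def apply(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(sample)
        return sample
    return apply


pipeline = build_pipeline([
    rename_key('observation.state', 'states'),
    scale_values('states', 0.5),  # works only because the rename ran first
])
result = pipeline({'observation.state': [2.0, 4.0]})
print(result)  # {'states': [1.0, 2.0]}
```

Order matters in such a chain: swapping the two steps would raise a KeyError, because 'states' does not exist until the rename has run.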

Below is a complete data configuration example:

train_dataloader = dict(
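    per_device_batch_size=8,  # batch size per GPU device
    per_device_num_workers=4,  # data loading workers per device
    dataset=dict(
        # repeated sampling for distributed training
        type='DistributedRepeatingDataset',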
        name_mappings={
            'observation.state': ['proprio', 'action']
        },
        statistic_keys=[
            'observation.state',
            'observation.eepose',
            'timestamp'
        ],
        datasets=[
            dict(
                type='ParquetDataset',
                data_root_path=[
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot/20251030_20251030_01_4090',
                    '/path/to/data/RealRobot_AgileX_aloha_lerobot/20251031_20251031_01_4090',
                ],
                transforms=[
                    dict(
                        type='ProcessParquetInputs',
                        parquet_keys=[
                            'observation.state',
                            'timestamp',
                            'actions',
                            'info',
                            'stats',
                            'action_masks'
                        ],
                        video_keys=[
                            'observation.images.cam_high',
                            'observation.images.cam_left_wrist',
                            'observation.images.cam_right_wrist',
                        ],
                        name_mappings={
                            'observation.state': ['states'],
                            'actions': ['actions']
                        }),
                    dict(
                        type='NormalizeStatesAndActions',
                        action_dim=32,
                        state_dim=32,
                        state_key='proprio',
                        action_key='action',
                        norm_type='min_max'),
                    dict(type='PreparePromptWithState'),
                    dict(
                        type='ProcessPrompts',
                        max_len=200,
                        tokenizer=dict(
                            type='PretrainedTokenizer',
                            model_path='/path/to/checkpoints/paligemma-3b-pt-224',
                        )),
                    dict(type='ResizeImages', height=224, width=224),
                    dict(type='SimpleNormalizeImages'),
                ],
                action_window_size=50)
        ]))
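
To make the link between statistic_keys and norm_type='min_max' more concrete, here is a minimal sketch of min-max normalization driven by precomputed statistics. It assumes the statistics are per-dimension minimum and maximum values and that the target range is [-1, 1]. The actual range and implementation of NormalizeStatesAndActions may differ, and compute_min_max and min_max_normalize are hypothetical names used only for illustration.

```python
from typing import Dict, List, Sequence


def compute_min_max(rows: Sequence[Sequence[float]]) -> Dict[str, List[float]]:
    """Compute the per-dimension min and max over equally sized vectors."""
    columns = list(zip(*rows))
    return {
        'min': [min(col) for col in columns],
        'max': [max(col) for col in columns],
    }


def min_max_normalize(
    vec: Sequence[float],
    stats: Dict[str, List[float]],
    eps: float = 1e-8,
) -> List[float]:
    """Map each dimension linearly from [min, max] to [-1, 1] (assumed range)."""
    out = []
    for x, lo, hi in zip(vec, stats['min'], stats['max']):
        span = max(hi - lo, eps)  # avoid division by zero for constant dimensions
        out.append(2.0 * (x - lo) / span - 1.0)
    return out


# Toy 2-dimensional states standing in for 'observation.state' values.
states = [[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]]
stats = compute_min_max(states)
normalized = [min_max_normalize(s, stats) for s in states]
print(normalized)  # [[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]
```

Computing the statistics once over the whole dataset, rather than per batch, keeps the mapping identical for every sample, which is presumably why the keys are declared up front in statistic_keys.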