Data Transforms

In this tutorial, we introduce the design of the transform pipeline in MMGeneration and explain how to customize data transformations.

Design of Data pipelines

Following typical conventions, we use Dataset and DataLoader for data loading with multiple workers. Dataset returns a dict of data items corresponding to the arguments of the model's forward method.

In the 1.x version of MMGeneration, all data transformations inherit from BaseTransform. The input and output of a transformation are both a dict. A simple example is as follows:

>>> from mmgen.datasets.transforms import LoadPairedImageFromFile
>>> transforms = LoadPairedImageFromFile(
>>>     key='pair',
>>>     domain_a='mask',
>>>     domain_b='photo',
>>>     flag='color')
>>> data_dict = {'pair_path': './data/pix2pix/facades/train/1.png'}
>>> data_dict = transforms(data_dict)
>>> print(data_dict.keys())
dict_keys(['pair_path', 'pair', 'pair_ori_shape', 'img_mask', 'img_photo', 'img_mask_path', 'img_photo_path', 'img_mask_ori_shape', 'img_photo_ori_shape'])

Generally, the last step of the transform pipeline must be PackGenInputs. PackGenInputs packs the processed data into a dict containing two fields: inputs and data_samples. inputs is the variable you want to use as the model's input, which can be a torch.Tensor, a dict of torch.Tensor, or any other type you need. data_samples is a list of GenDataSample, and each GenDataSample contains the ground truth and the necessary information for the corresponding input.
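For reference, the packing step can also be run on its own. The following is a minimal sketch that continues from the data_dict produced above; the keys and meta_keys passed to PackGenInputs here are illustrative and should match the fields your pipeline actually produces.

>>> from mmgen.datasets.transforms import PackGenInputs
>>> pack = PackGenInputs(keys=['img_mask', 'img_photo'], meta_keys=['pair_path'])
>>> packed = pack(data_dict)
>>> print(packed.keys())
dict_keys(['inputs', 'data_samples'])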

Here is a pipeline example for Pix2Pix training on the aerial2maps dataset.

domain_a = 'aerial'
domain_b = 'map'

pipeline = [
    dict(
        type='LoadPairedImageFromFile',
        io_backend='disk',
        key='pair',
        domain_a=domain_a,
        domain_b=domain_b,
        flag='color'),
    dict(
        type='TransformBroadcaster',
        mapping={'img': [f'img_{domain_a}', f'img_{domain_b}']},
        auto_remap=True,
        share_random_params=True,
        transforms=[
            dict(
                type='mmgen.Resize', scale=(286, 286),
                interpolation='bicubic'),
            dict(type='mmgen.FixedCrop', crop_size=(256, 256))
        ]),
    dict(
        type='Flip',
        keys=[f'img_{domain_a}', f'img_{domain_b}'],
        direction='horizontal'),
    dict(
        type='PackGenInputs',
        keys=[f'img_{domain_a}', f'img_{domain_b}', 'pair'],
        meta_keys=[
            'pair_path', 'sample_idx', 'pair_ori_shape',
            f'img_{domain_a}_path', f'img_{domain_b}_path',
            f'img_{domain_a}_ori_shape', f'img_{domain_b}_ori_shape', 'flip',
            'flip_direction'
        ])
]
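In a config file, this pipeline is then passed to the dataset. The snippet below is an illustrative sketch of how it might be wired into a dataloader config; the dataset type, data_root, and loader settings are placeholders and should be taken from your actual Pix2Pix config.

# Illustrative only: plug the pipeline defined above into a dataloader config.
train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dict(
        type='PairedImageDataset',  # placeholder dataset type; use the one from your config
        data_root='./data/pix2pix/maps',
        pipeline=pipeline))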

Customizing data transformations

A customized data transformation must inherit from BaseTransform and implement the transform method. Here we use a simple flipping transformation as an example:

import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        # Flip the image stored under 'img' and write it back to the results dict.
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

We can then instantiate a MyFlip object and use it to process a data dict:

import numpy as np

transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']

Alternatively, we can use the MyFlip transformation in the data pipeline of our config file:

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]

Note that to use MyFlip in a config file, you must ensure that the module containing MyFlip is imported at runtime, as shown below.
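One common way to guarantee this in OpenMMLab-style configs is the custom_imports field, which imports the listed modules when the config is loaded. The module path below is a placeholder for wherever you placed MyFlip.

custom_imports = dict(
    imports=['path.to.my_flip'],  # placeholder: the module that defines MyFlip
    allow_failed_imports=False)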
