PyTorch custom Datasets, DataLoaders, and Transforms

This article comes from:

A lot of the effort in solving any machine learning problem goes into preparing the data. PyTorch provides many tools to make data loading easy and, hopefully, to make your code more readable. In this tutorial, we will see how to load and preprocess/augment data from a non-trivial dataset.

To run this tutorial, make sure these packages are installed:

scikit-image: for image I/O and transforms
pandas: for easier CSV parsing

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

The dataset we are going to deal with is a facial pose dataset, in which each face is annotated with landmark points.

In total, 68 different landmark points are annotated for each face.


Download the dataset from here so that the images are in a directory named 'data/faces/'. This dataset was actually generated by applying dlib's pose estimation to a few images from ImageNet tagged as "face".

The dataset comes with a CSV file of annotations that looks like this:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... ,84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

Let's quickly read the CSV file and get the annotations for one image as an (N, 2) array, where N is the number of landmarks.

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()  # as_matrix() was removed in pandas 1.0
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
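To make the reshape(-1, 2) step concrete, here is a small sketch that runs the same parsing on a made-up, in-memory CSV. It assumes the same column layout as face_landmarks.csv (file name first, then x/y pairs) but uses only 3 landmarks per row instead of 68; the file names and coordinates are invented for illustration.

```python
from io import StringIO

import numpy as np
import pandas as pd

# Made-up CSV with the same layout as face_landmarks.csv, but only 3 landmarks.
csv_text = """image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x,part_2_y
a.jpg,27,83,27,98,30,112
b.jpg,70,236,71,257,75,277
"""
frame = pd.read_csv(StringIO(csv_text))

n = 1
img_name = frame.iloc[n, 0]                      # first column: the file name
flat = frame.iloc[n, 1:].to_numpy()              # x0, y0, x1, y1, x2, y2
landmarks = flat.astype('float').reshape(-1, 2)  # one (x, y) row per landmark

print(img_name)         # b.jpg
print(landmarks.shape)  # (3, 2)
```

Because the columns alternate x, y, x, y, reshaping the flat row with two columns puts each landmark's x in column 0 and its y in column 1.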

Let's write a simple helper function to display an image and its annotations, and use it to display an example.

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()

torch.utils.data.Dataset is an abstract class representing a dataset. Your custom dataset should inherit Dataset and override the following methods:

__len__ so that len(dataset) returns the size of the dataset.
__getitem__ to support indexing, so that dataset[i] returns the i-th sample.
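As a minimal sketch of these two methods before the real dataset, here is a hypothetical toy dataset (SquaresDataset is an illustrative name, not part of PyTorch) whose i-th sample is the pair (i, i squared), computed on demand rather than stored.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i ** 2)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Makes len(dataset) work.
        return self.size

    def __getitem__(self, idx):
        # Makes dataset[idx] work; the sample is computed only when requested.
        if idx < 0 or idx >= self.size:
            # IndexError also lets a plain `for sample in dataset` loop stop.
            raise IndexError(idx)
        x = torch.tensor(float(idx))
        return x, x ** 2


squares = SquaresDataset(5)
print(len(squares))  # 5
x, y = squares[3]    # two scalar tensors holding 3.0 and 9.0
```

Raising IndexError for out-of-range indices is a small design choice: since Dataset defines no __iter__, Python falls back to calling __getitem__ with 0, 1, 2, ... and stops at the first IndexError.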

Let's create a dataset class for our face landmarks dataset. We will read the CSV in __init__ but leave the reading of images to __getitem__. This is memory efficient because the images are not all stored in memory at once; each one is read only when it is needed.

A sample of our dataset will be a dict {'image': image, 'landmarks': landmarks}. Our dataset will take an optional argument transform so that any required processing can be applied to the sample. We will see the usefulness of transform in the next section.

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

Let's instantiate this class and iterate over the data samples. We will print the size of the first four samples and display their landmarks.

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

From the above we can see that the samples are not all the same size. Most neural networks expect images of a fixed size, so we need to write some preprocessing code. Let's create three transforms:

Rescale: to scale the image
RandomCrop: to crop the image randomly; this is data augmentation
ToTensor: to convert NumPy images to torch images (this requires swapping axes)

We will write them as callable classes instead of simple functions so that the parameters of a transform need not be passed every time it is called. For this, we only need to implement the __call__ method and, if required, the __init__ method. We can then use a transform like this:

tsfm = Transform(params)
transformed_sample = tsfm(sample)
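Here is a minimal sketch of that pattern, using a hypothetical ShiftLandmarks transform (not one of the three transforms in this tutorial) that moves every landmark by a fixed offset. The offset is bound once in __init__, and each call only receives the sample dict.

```python
import numpy as np


class ShiftLandmarks(object):
    """Hypothetical transform: move every landmark by a fixed (dx, dy)."""

    def __init__(self, dx, dy):
        # Parameters are stored once, when the transform is constructed.
        self.offset = np.array([dx, dy], dtype=float)

    def __call__(self, sample):
        # Each call only needs the sample; the image is passed through unchanged.
        image, landmarks = sample['image'], sample['landmarks']
        return {'image': image, 'landmarks': landmarks + self.offset}


tsfm = ShiftLandmarks(3, -1)
sample = {'image': np.zeros((4, 4)),
          'landmarks': np.array([[0.0, 0.0], [5.0, 2.0]])}
transformed_sample = tsfm(sample)
print(transformed_sample['landmarks'].tolist())  # [[3.0, -1.0], [8.0, 1.0]]
```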

Note below how each transform has to be applied to both the image and the landmarks.

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}
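To check the size rule and the landmark scaling by hand, the sketch below pulls the arithmetic of Rescale.__call__ into a hypothetical helper, rescaled_size, and skips the actual image resize so it needs only NumPy. The image sizes are chosen so that the scale factors are exact.

```python
import numpy as np


def rescaled_size(h, w, output_size):
    """Hypothetical helper: the same size rule as Rescale.__call__."""
    if isinstance(output_size, int):
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size
    return int(new_h), int(new_w)


# Landscape 512 x 1024 (h x w) image: the shorter side (h) becomes 256.
h, w = 512, 1024
new_h, new_w = rescaled_size(h, w, 256)
print(new_h, new_w)  # 256 512

# Landmarks are (x, y) pairs: x scales with the width, y with the height.
landmarks = np.array([[512.0, 256.0], [1024.0, 512.0]])
scaled = landmarks * [new_w / w, new_h / h]
```

Because both factors are 0.5 here, the aspect ratio is preserved and every landmark lands on the same relative position in the smaller image.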

class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

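        # +1 because randint's upper bound is exclusive; this also allows a
        # crop that is exactly as large as the image
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)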

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
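One way to convince yourself that landmarks - [left, top] is the right correction: mark the pixel under a landmark, crop, and check that the shifted landmark still points at the marked pixel. The sketch below does this with a hypothetical random_crop function that repeats the RandomCrop logic on a plain array, using NumPy's Generator.integers in place of np.random.randint so the randomness is reproducible.

```python
import numpy as np


def random_crop(image, landmarks, new_h, new_w, rng):
    """Hypothetical function: the RandomCrop logic with an explicit generator."""
    h, w = image.shape[:2]
    # integers() excludes its upper bound, hence the +1 (as in randint).
    top = rng.integers(0, h - new_h + 1)
    left = rng.integers(0, w - new_w + 1)
    image = image[top: top + new_h, left: left + new_w]
    # Landmarks are (x, y): x moves by the left offset, y by the top offset.
    landmarks = landmarks - [left, top]
    return image, landmarks


rng = np.random.default_rng(0)
image = np.zeros((10, 10), dtype=int)
x, y = 5, 5
image[y, x] = 1                    # mark the pixel under the landmark (row y, column x)
landmarks = np.array([[x, y]])

cropped, moved = random_crop(image, landmarks, 8, 8, rng)
cx, cy = moved[0]
print(cropped.shape)               # (8, 8)
print(cropped[cy, cx])             # 1: the landmark still sits on the marked pixel
```

Note the index order: images are indexed [row, column] = [y, x], while landmarks are stored as (x, y).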

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
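The axis swap in ToTensor can be checked in isolation. This sketch uses a random array standing in for an RGB image and shows that transpose((2, 0, 1)) turns H x W x C into C x H x W, with each channel becoming one leading slice of the tensor.

```python
import numpy as np
import torch

# A fake 4 x 6 RGB image in NumPy's H x W x C layout.
image = np.random.rand(4, 6, 3)

chw = image.transpose((2, 0, 1))  # C x H x W, the layout torch expects
tensor = torch.from_numpy(chw)    # shares memory with the NumPy array

print(image.shape)   # (4, 6, 3)
print(tensor.shape)  # torch.Size([3, 4, 6])
```

A 2-D grayscale array has no channel axis, so it would need one added first (for example image[:, :, None]) before this transpose works. Also, NumPy's default float64 becomes a float64 tensor; many models expect float32, which .float() provides.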

Compose transforms

Now, we apply the transforms to a sample.

Let's say we want to rescale the shorter side of the image to 256 and then randomly crop a square of size 224 from it, i.e. we want to compose the Rescale and RandomCrop transforms. torchvision.transforms.Compose is a simple callable class that lets us do this.

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms to the sample
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
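Conceptually, composing transforms just means feeding each transform's output into the next one, in list order. The sketch below is a hypothetical stand-in, compose, written only to show that idea (it is not torchvision's implementation), applied to plain numbers so the effect of ordering is easy to see.

```python
def compose(transforms):
    """Hypothetical stand-in for transforms.Compose: apply callables in order."""
    def apply(sample):
        for t in transforms:
            sample = t(sample)
        return sample
    return apply


def double(s):
    return s * 2


def add_one(s):
    return s + 1


print(compose([double, add_one])(5))  # 11, i.e. 5 * 2 + 1
print(compose([add_one, double])(5))  # 12, i.e. (5 + 1) * 2
```

Order matters in the same way above: Rescale(256) runs first so that RandomCrop(224) always receives an image whose shorter side is at least 224.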

Iterate over a dataset

Let's put all of this together to create a dataset with composed transforms. To summarize, every time this dataset is sampled:

An image is read from its file on the fly.
The transforms are applied to the image that was read.
Since one of the transforms is random, the data is augmented at sampling time.

We can iterate over the created dataset with a for i in range loop as before.

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

However, by iterating over the data with a simple for loop, we lose a lot of features. In particular, we miss out on:

Batching the data
Shuffling the data
Loading the data in parallel using multiprocessing workers

torch.utils.data.DataLoader is an iterator that provides all of these features. The parameters used below should be clear. One parameter of interest is collate_fn: you can use it to specify exactly how the samples are batched. However, the default collation works fine for most use cases.
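To see what the default batching does with dict samples like ours, here is a small sketch using a hypothetical DictDataset whose samples mimic the shapes produced by ToTensor (3 x H x W images and 68 x 2 landmarks). The image size, dataset length, and the list-returning collate_fn are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical dataset whose samples are dicts, like FaceLandmarksDataset."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {'image': torch.full((3, 8, 8), float(idx)),
                'landmarks': torch.zeros(68, 2)}


# Default collate: keeps the dict keys and stacks each value along a new
# leading batch dimension. 10 samples in batches of 4 give sizes 4, 4, 2.
loader = DataLoader(DictDataset(), batch_size=4, shuffle=False)
batches = list(loader)
for batch in batches:
    print(batch['image'].shape, batch['landmarks'].shape)

# A custom collate_fn can batch differently, e.g. keep a plain list of samples.
list_loader = DataLoader(DictDataset(), batch_size=4,
                         collate_fn=lambda samples: samples)
first = next(iter(list_loader))  # a list of 4 sample dicts
```

With num_workers greater than 0, a lambda collate_fn may fail to pickle on platforms that start workers with spawn, so a module-level function is the safer choice there.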

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2  # make_grid pads each image by this many pixels

    grid = utils.make_grid(images_batch, padding=grid_border_size)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Offset each image's landmarks by its position in the padded grid.
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

    plt.title('Batch from dataloader')

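for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break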

Afterword: torchvision

In this tutorial, we have seen how to write and use datasets, transforms, and DataLoader. The torchvision package provides some common datasets and transforms, so you might not even have to write custom classes. One of the more generic datasets available in torchvision is ImageFolder. It assumes that images are organized in the following way:

root/ants/xxx.png
root/ants/xxy.jpeg
...
root/bees/123.jpg
root/bees/nsdf3.png

where 'ants', 'bees', etc. are class labels. Similarly, generic transforms that operate on PIL.Image, such as RandomHorizontalFlip and Resize (called Scale in older torchvision releases), are also available. You can use these to write a dataloader like this:

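import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),  # fixed size so samples can be batched
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),              # Normalize expects a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)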
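ImageFolder derives each image's label from the name of the subdirectory it sits in. The sketch below imitates that convention with two hypothetical helpers, folder_classes and folder_samples, on an empty throwaway directory tree. It assumes the behaviour described in torchvision's documentation (class names are the sorted subdirectory names, numbered from 0) and, unlike ImageFolder, does not filter files by image extension or open them.

```python
import os
import tempfile


def folder_classes(root):
    """Hypothetical helper: sorted subdirectory names mapped to 0, 1, 2, ..."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def folder_samples(root):
    """Hypothetical helper: (path, label) pairs for every file in each class dir."""
    samples = []
    for name, idx in folder_classes(root).items():
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), idx))
    return samples


with tempfile.TemporaryDirectory() as root:
    for path in ['bees/123.jpg', 'ants/xxx.png', 'ants/xxy.jpeg']:
        full = os.path.join(root, path)
        os.makedirs(os.path.dirname(full), exist_ok=True)
        open(full, 'wb').close()  # empty placeholder files
    classes = folder_classes(root)
    labels = [label for _, label in folder_samples(root)]

print(classes)  # {'ants': 0, 'bees': 1}
print(labels)   # [0, 0, 1]
```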

Posted on Wed, 11 Mar 2020 02:09:12 -0700 by CodeToad