Quick Tutorial on Distributed Data Parallel Training in PyTorch with Multiple GPUs

A quick tutorial on distributed data parallel training in PyTorch with multiple GPUs, so beginners can start training in just a few minutes.


DP vs. DDP

PyTorch provides two built-in implementations for multi-GPU training.

  • DataParallel (DP): parameter-server mode, with one GPU acting as the reducer; extremely simple to use, requiring only one line of code.

  • DistributedDataParallel (DDP): all-reduce mode, designed for multi-node distributed training, but also usable for training on a single node with multiple GPUs.

DataParallel is based on the parameter-server architecture and is simple to adopt: just add one line to the original single-node, single-GPU code:

model = nn.DataParallel(model, device_ids=[0, 1])  # device_ids is a list of GPU indices

But its load imbalance problem is serious: when the model is large (e.g. bert-large), the reducer GPU can occupy an extra 3-4 GB of memory.

Its speed is also slower than DDP's.


The official recommendation is to use the newer DDP, which is based on the all-reduce algorithm. It was designed mainly for multi-node multi-GPU training, but it works just as well on a single node with multiple GPUs.
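For a first taste of the API, here is a minimal, self-contained sketch run as a single process on CPU with the gloo backend, purely for illustration (the address and port values are arbitrary choices); real multi-GPU training launches one process per GPU and typically uses the nccl backend:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process setup (world_size=1) so the sketch runs anywhere;
# MASTER_ADDR/MASTER_PORT tell the processes how to rendezvous.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29577")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Wrapping the model is the only model-side change DDP requires;
# gradients are all-reduced across processes during backward().
model = DDP(nn.Linear(8, 2))
out = model(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 2])

dist.destroy_process_group()
```

With more than one process, each one would run this same script; DDP keeps the replicas in sync by averaging gradients after every backward pass.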

First, a few concepts need to be clarified.

  • rank

Multi-node/multi-GPU: the global index of the current process across all nodes

Single-node/multi-GPU: the index of a particular GPU (one process per GPU)

  • world_size

Multi-node/multi-GPU: the total number of processes across all nodes

Single-node/multi-GPU: the number of GPUs

  • local_rank

Multi-node/multi-GPU: the index of a GPU within its own node

Single-node/multi-GPU: the index of a GPU (the same as rank in this case)
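In practice, when a script is launched with torchrun, each spawned process receives these three values as environment variables. A minimal way to inspect them (the fallback defaults below are assumptions for a plain, non-distributed launch):

```python
import os

# torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every process it spawns;
# the fallbacks below simply let the script also run as a plain `python` launch.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

print(f"process {rank} of {world_size}, using GPU {local_rank} on this node")
```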

Single-node single-GPU training

Let's start with a demo of single-node, single-GPU training code and quickly walk through the data flow. The demo is small but complete: it contains all the steps of a typical deep learning training run, including defining and instantiating the model and dataset, the loss function, defining the optimizer, zeroing gradients, backpropagating gradients, stepping the optimizer, and printing the training log.

Next, we will use the DistributedDataParallel provided by PyTorch to convert this single-node, single-GPU training procedure into single-node, multi-GPU parallel training.

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=str, default='0,2')
parser.add_argument('--batchSize', type=int, default=32)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--dataset-size', type=int, default=128)
parser.add_argument('--num-classes', type=int, default=10)
config = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id

# Define dataset sample 
class RandomDataset(Dataset):
    def __init__(self, dataset_size, image_size=32):
        images = torch.randn(dataset_size, 3, image_size, image_size)
        labels = torch.zeros(dataset_size, dtype=torch.long)  # dummy labels (all class 0)
        self.data = list(zip(images, labels))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# Define the model
class Model(nn.Module):
    def __init__(self, num_classes):
        super(Model, self).__init__()
        self.conv2d = nn.Conv2d(3, 16, 3)
        self.fc = nn.Linear(30*30*16, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.conv2d(x)
        x = x.reshape(batch_size, -1)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so adding an explicit Softmax here would be a bug
        out = self.fc(x)
        return out

# Instantiate models, datasets, loaders, and optimizers
model = Model(config.num_classes)
dataset = RandomDataset(config.dataset_size)
loader = DataLoader(dataset, batch_size=config.batchSize, shuffle=True)
loss_func = nn.CrossEntropyLoss()

if torch.cuda.is_available():
    model = model.cuda()
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)

# If using DP, only one line is needed:
# if torch.cuda.device_count() > 1:
#     model = nn.DataParallel(model)

# Training loop: zero gradients, forward, compute loss, backward, step, log
for epoch in range(config.epochs):
    for step, (images, labels) in enumerate(loader):
        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch}, Step {step}, Loss {loss.item():.4f}')