Lightning AI, published Oct 19, 2023

Step-By-Step Walk-Through of Pytorch Lightning


Takeaways

Learn step by step how to train a convolutional neural network for image classification on the CIFAR-10 dataset using PyTorch Lightning, with callbacks and loggers for monitoring model performance.

In this blog, you will learn about the different components of PyTorch Lightning and how to train an image classifier on the CIFAR-10 dataset with it. We will also discuss how to use loggers and callbacks such as Tensorboard and ModelCheckpoint.

PyTorch Lightning is a high-level wrapper over PyTorch that makes model training easier and more scalable by removing boilerplate, so that you can focus on experiments and research rather than on engineering the training process. It is a great way for beginners to start with deep learning, and it also serves experts who want to scale training to billion+ parameter models like Llama and Stable Diffusion.

We will begin by getting familiar with the key components of PyTorch Lightning and then use them to train an image classification model. We will also record our experiments with a logger such as Tensorboard to monitor and visualize the metrics. You can access the code used for this blog here.

Components of PyTorch Lightning

PyTorch Lightning consists of two primary components: LightningModule and Trainer. They play a crucial role in organizing and automating the phases of the model training lifecycle. Let’s go through each of them step by step.

⚡ LightningModule – Organizes the Training Loop

LightningModule contains all the logic for model initialization, training/validation steps, and the calculation of loss and accuracy metrics. It organizes the PyTorch code into six sections:

- Initialization (__init__ and setup())
- Train logic (training_step())
- Validation loop (validation_step())
- Test logic (test_step())
- Prediction logic (predict_step())
- Optimizers and LR Schedulers (configure_optimizers())

The example below shows a sample implementation of the LightningModule.

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl


    class MyLitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # load_model(...) is a placeholder for whatever builds your nn.Module
            self.model = load_model(...)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x, y = batch  # inputs and labels
            logits = self(x)
            loss = self.loss_fn(logits, y)
            self.log('train_loss', loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

In this example, the model is initialized in the __init__ method, and we define training_step, which takes the batch and batch_idx arguments. We unpack the inputs x and labels y from the batch, pass the inputs through the model, and calculate the cross-entropy loss. PyTorch Lightning automatically calls loss.backward() and steps the Adam optimizer returned by the configure_optimizers method. You don’t need to manually move the tensors from the CPU to the GPU; Lightning ⚡ will take care of that for you.

Lightning Trainer – Automating the Training Process

Once we have organized our training code with the LightningModule and loaded the dataset, we are all set to begin the training process using the Lightning Trainer. It simplifies mixed precision training, GPU selection, setting the number of devices, distributed training, and much more. The Trainer class has 35 flags (at the time of writing this blog) that can be used for various tasks, ranging from defining the number of epochs to scaling the training for models with billions of parameters.
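To make concrete what the Trainer takes off your hands, the sketch below writes out by hand roughly the loop that runs around training_step and configure_optimizers. It is plain PyTorch and not Lightning's actual implementation; the toy linear model, the random data, and the helper name manual_fit are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def manual_fit(model, dataloader, epochs=1, lr=0.001):
    """Hand-written bare-bones training run (illustrative only)."""
    # Device selection and placement: handled by the Trainer in Lightning.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    # Plays the role of configure_optimizers().
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    losses = []
    model.train()
    for _ in range(epochs):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)  # the body of training_step()
            optimizer.zero_grad()
            loss.backward()              # done automatically by Lightning
            optimizer.step()             # done automatically by Lightning
            losses.append(loss.item())
    return losses


# Hypothetical stand-in data: 64 random flattened "images", 10 classes.
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(64, 3 * 32 * 32), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16)
toy_model = nn.Linear(3 * 32 * 32, 10)

losses = manual_fit(toy_model, loader, epochs=2)
print(f"ran {len(losses)} optimizer steps")
```

Everything outside the two lines marked as training_step and configure_optimizers is the kind of boilerplate the Trainer owns, and features such as mixed precision or multi-GPU training would add considerably more of it.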

💡 PyTorch Lightning also offers LightningDataModule, which can be used to organize the PyTorch datasets and dataloaders. It also automates data loading in a distributed training environment. In this blog, we won’t discuss datamodules, since you can also use a DataLoader directly, but I encourage readers to read about them in the official docs here.

Let’s explore how to use the Lightning Trainer with a LightningModule and go through a few of the flags using the example below. We create a Lightning Trainer object with 4 GPUs, enable mixed-precision training with the float16 data type, and instantiate the MyLitModel model that we defined in the previous section. Finally, we start the training by passing the model and the dataloader to the trainer.fit method.

    trainer = pl.Trainer(
        devices=4,
        accelerator="gpu",
        precision="16-mixed",  # float16 mixed precision
    )

    model = MyLitModel()
    # train_dataloader is a torch.utils.data.DataLoader created beforehand
    trainer.fit(model, train_dataloaders=train_dataloader)

Loggers and Callbacks

You can also add a logger, such as Tensorboard, WandB, Comet, or a simple CSVLogger, to monitor the loss or any other metrics that you’ve logged during training. For simplicity, we will use Tensorboard in this blog. You can just import the TensorBoardLogger and add it to the Trainer as shown below:

    from pytorch_lightning.loggers import TensorBoardLogger

    trainer = pl.Trainer(logger=TensorBoardLogger(save_dir="logs/"))
    trainer.fit(model, train_dataloader, val_dataloader)

To start the Tensorboard web UI, run the command tensorboard --logdir logs/ from your terminal; it serves the UI on the default port 6006.

PyTorch Lightning provides several built-in callbacks, such as BatchSizeFinder, EarlyStopping, ModelCheckpoint, and more. These callbacks hook into various stages of the loop to manage and adjust training. In this blog, we will use the EarlyStopping callback to automatically stop our training once the monitored metric (e.g., validation loss) stops improving. The metric must be one that the module logs, for example with self.log('val_loss', ...) in validation_step. You can also configure other arguments, such as patience, which sets the number of validation checks without improvement after which training stops.

    from pytorch_lightning.callbacks import EarlyStopping

    early_stopping = EarlyStopping('val_loss', patience=7)
    trainer = pl.Trainer(callbacks=[early_stopping])
    trainer.fit(model, train_dataloader, val_dataloader)

Training an Image Classifier (Convolutional Neural Networks or CNN) on the CIFAR-10 dataset using PyTorch Lightning

Now that we have learned about LightningModule and Trainer, we will proceed to train an image classification model using the CIFAR-10 dataset. We will begin by loading the dataset from torchvision, defining the model, training_step, and…
