PyTorch Lightning is a high-level wrapper that organizes PyTorch code to remove boilerplate, enforce best practices, and enable scalable training. It handles distributed training, mixed precision, callbacks, logging, and more, while leaving you full control over the training loop. The framework is designed for researchers who need production-grade code without sacrificing flexibility: you write the research code in a LightningModule, and Lightning handles the engineering complexity. The key insight is that Lightning doesn't abstract your PyTorch code; it structures it, making models reproducible, shareable, and scalable from laptop to supercomputer with minimal code changes.
What This Cheat Sheet Covers
This topic spans 20 focused tables and 112 indexed concepts. Below is a complete table-by-table outline, from foundational concepts through advanced details.
Table 1: LightningModule Core Methods
The LightningModule is where all your research code lives, and these methods are the hooks Lightning calls at each phase of the loop. Define your architecture in `__init__`, return a loss from `training_step`, hand back optimizers from `configure_optimizers`, and Lightning wires the rest of the engineering together. Master these seven and you can build almost any model.
| Method | Example | Description |
|---|---|---|
| `__init__` | `def __init__(self, lr=1e-3): super().__init__(); self.save_hyperparameters(); self.model = nn.Linear(10, 1)` | • Define model architecture, loss functions, and hyperparameters • Call `self.save_hyperparameters()` to auto-save all `__init__` arguments to `self.hparams` for checkpointing |
| `forward` | `def forward(self, x): return self.model(x)` | • Standard PyTorch forward pass for inference • Called by `predict_step` by default and can be used independently of training |
| `training_step` | `def training_step(self, batch, batch_idx): x, y = batch; y_hat = self(x); loss = F.mse_loss(y_hat, y); self.log('train_loss', loss); return loss` | • Computes loss for a single training batch • Must return the loss tensor for automatic optimization • Use `self.log()` to track metrics |
| `validation_step` | `def validation_step(self, batch, batch_idx): x, y = batch; y_hat = self(x); loss = F.mse_loss(y_hat, y); self.log('val_loss', loss)` | • Evaluates model on validation data • No backward pass or optimizer step needed • Log validation metrics with `self.log()` |