Profiling PyTorch Lightning in SageMaker

This post shows how to profile a training loop written in PyTorch Lightning on Amazon SageMaker and view the results in TensorBoard.

Code

Import TensorBoardLogger, which logs scalars and histograms as well as the model graph:

from pytorch_lightning.loggers import TensorBoardLogger

Next, create a callback that advances the profiler by one step at the end of every training batch:

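from pytorch_lightning.callbacks import Callback


class TorchTensorboardProfilerCallback(Callback):
    """Quick-and-dirty callback that steps a torch.profiler.profile instance during training.

    For greater robustness, extend the pl.profiler.profilers.BaseProfiler. See
    https://pytorch-lightning.readthedocs.io/en/stable/advanced/profiler.html"""

    def __init__(self, profiler):
        super().__init__()
        self.profiler = profiler  # a torch.profiler.profile instance

    def on_train_batch_end(self, trainer, pl_module, outputs, *args, **kwargs):
        # One profiler step per training batch; *args/**kwargs absorb batch,
        # batch_idx (and dataloader_idx in older Lightning versions).
        self.profiler.step()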
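The callback only calls `step()`; what the profiler does at each step is decided by the `schedule` given to `torch.profiler.profile`. The sketch below uses a hypothetical schedule of `wait=1, warmup=1, active=3, repeat=1` (the post does not specify one) and prints the action taken at each step number:

```python
from torch.profiler import ProfilerAction, schedule

# Hypothetical schedule: skip 1 step, warm up for 1 step, record 3 steps,
# and run this cycle only once.
sched = schedule(wait=1, warmup=1, active=3, repeat=1)

# Each profiler.step() call moves to the next step number; the schedule maps
# that number to what the profiler does during that step.
actions = [sched(step) for step in range(6)]
for step, action in enumerate(actions):
    print(step, action.name)
```

Steps 0 to 5 map to NONE, WARMUP, RECORD, RECORD, RECORD_AND_SAVE and NONE. The trace is handed to `on_trace_ready` when the RECORD_AND_SAVE step ends, so only a handful of batches are recorded instead of the whole epoch.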
        

To log the model graph, set an example input inside your LightningModule (typically in __init__) and pass log_graph=True to the TensorBoardLogger:

self.example_input_array = torch.rand((1, 3, 32, 32))  # a batch of one 3-channel 32x32 image

Finally, wrap the Trainer setup and the trainer.fit call inside the torch.profiler.profile context manager, passing the profiler object to the callback.

When the training job finishes, SageMaker compresses the resulting files and uploads them to the job's S3 output location.
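To view the logs locally, download the archive and unpack it before starting TensorBoard. A minimal sketch, assuming the logs were saved under the model directory so they end up in `model.tar.gz`; the bucket and key in the commented download step are placeholders, and `extract_tensorboard_logs` is a hypothetical helper:

```python
import os
import tarfile


def extract_tensorboard_logs(archive_path: str, dest_dir: str) -> str:
    """Unpack a downloaded SageMaker artifact (e.g. model.tar.gz) and return
    the directory to pass to `tensorboard --logdir`."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    return dest_dir


# Download step (requires AWS credentials; bucket and key are placeholders).
# The artifact is typically stored at <output_path>/<job-name>/output/model.tar.gz.
# import boto3
# boto3.client("s3").download_file("my-bucket", "my-prefix/my-job/output/model.tar.gz", "model.tar.gz")
# logdir = extract_tensorboard_logs("model.tar.gz", "logs")
```

After extraction, run `tensorboard --logdir logs` and open the address it prints.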

For the complete code, visit the repo.


Author | MMG
