Profiling PyTorch Lightning in SageMaker
This post shows how to profile a training loop written in PyTorch Lightning on SageMaker and view the results in TensorBoard.
Code
Import `TensorBoardLogger` to log scalars and histograms along with the model graph:

```python
from pytorch_lightning.loggers import TensorBoardLogger
```
Create a callback class for the profiler
```python
from pytorch_lightning.callbacks import Callback


class TorchTensorboardProfilerCallback(Callback):
    """Quick-and-dirty callback for invoking the PyTorch profiler during training.

    For greater robustness, extend pl.profiler.profilers.BaseProfiler. See
    https://pytorch-lightning.readthedocs.io/en/stable/advanced/profiler.html
    """

    def __init__(self, profiler):
        super().__init__()
        self.profiler = profiler

    def on_train_batch_end(self, trainer, pl_module, outputs, *args, **kwargs):
        # Advance the profiler schedule once per training batch.
        self.profiler.step()
```
To log the model graph, create an example input inside your `LightningModule` and pass `log_graph=True` to the `TensorBoardLogger` object:

```python
self.example_input_array = torch.rand((1, 3, 32, 32))
```
Then simply wrap the `trainer.fit()` call inside the `torch.profiler.profile` context manager.
The resulting trace files will be compressed and uploaded to the S3 bucket.
For the complete code, visit the repo.