Saving and Loading Torch Models

I generally write my DL models (referred to as the ‘base model’ henceforth) in PyTorch, i.e. as subclasses of torch.nn.Module, and pass them as an argument to my PyTorch Lightning (PTL) training module (referred to as ‘Model’), a subclass of pl.LightningModule.
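Here is a minimal sketch of this layout. It assumes a small hypothetical CNN whose layers are made up, apart from fc2, which the inference code reads. To keep it runnable without Lightning installed, Wrapper is a plain torch.nn.Module standing in for the pl.LightningModule. Because LightningModule subclasses nn.Module, storing the base model as an attribute (here self.model, an assumed name) works the same way and prefixes every state_dict key with that attribute name.

```python
import torch
import torch.nn as nn


class CNN(nn.Module):
    """Hypothetical base model; layer names are assumptions (fc2 matches the text)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(4)  # fixed 4x4 output for any input size
        self.fc1 = nn.Linear(8 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = torch.flatten(x, 1)
        return self.fc2(torch.relu(self.fc1(x)))


class Wrapper(nn.Module):
    """Stand-in for the pl.LightningModule: receives the base model, keeps it as self.model."""

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.model = base_model

    def forward(self, x):
        return self.model(x)


wrapped = Wrapper(CNN())
# Every key now carries the attribute name as a prefix
print([k for k in wrapped.state_dict() if k.startswith("model.fc2")])
# ['model.fc2.weight', 'model.fc2.bias']
```

This prefix is the reason the keys in a saved checkpoint do not match those of a bare CNN instance.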

Let’s say we have deployed our model on the cloud and want to run inference. We still need the base model’s class definition, but we don’t want to initialize or rewrite the PTL training code just to read the trained model’s weights.

Please see the accompanying code for the complete example.

Open Training Notebook In Colab

Open Loading Notebook In Colab

The following code saves the best model using PTL’s ModelCheckpoint callback.

Training

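import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='train_acc',   # metric that ranks checkpoints; must be logged with self.log()
    dirpath='./ckpt',      # folder the .ckpt files are written to
    filename='model-{epoch:02d}-{val_loss:.2f}',  # {epoch} is supplied by PTL; val_loss must be logged
    mode='max'             # a higher train_acc is better
)
trainer = pl.Trainer(callbacks=[checkpoint_callback])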

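Once the callback is passed to the Trainer and training runs, the best checkpoint is saved as a ‘.ckpt’ file in the folder ‘ckpt’. With the default save_top_k=1, only that single best file is kept.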
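If save_top_k is raised above 1, several checkpoints coexist in the folder and os.listdir returns them in arbitrary order. In the training session itself, checkpoint_callback.best_model_path gives the best file directly. On a separate inference machine, one option is to rank files by the metric value that ModelCheckpoint writes into each filename. The helper below, best_checkpoint, is hypothetical. It assumes the ‘<metric>=<value>’ naming produced by the default auto_insert_metric_name=True, e.g. model-epoch=03-val_loss=0.12.ckpt.

```python
import re
from typing import Iterable


def best_checkpoint(filenames: Iterable[str], metric: str, mode: str = "max") -> str:
    """Pick the .ckpt filename with the best '<metric>=<number>' value.

    Assumes names like 'model-epoch=03-train_acc=0.91.ckpt', the style
    ModelCheckpoint uses when auto_insert_metric_name is left at True.
    """
    # The lookbehind stops 'acc' from matching inside 'train_acc'
    pattern = re.compile(rf"(?<![A-Za-z0-9_]){re.escape(metric)}=(-?\d+(?:\.\d+)?)")
    scored = []
    for name in filenames:
        if not name.endswith(".ckpt"):
            continue  # skip anything that is not a checkpoint
        match = pattern.search(name[: -len(".ckpt")])
        if match:
            scored.append((float(match.group(1)), name))
    if not scored:
        raise ValueError(f"no checkpoint filename contains '{metric}='")
    choose = max if mode == "max" else min
    return choose(scored, key=lambda item: item[0])[1]


files = ["model-epoch=01-train_acc=0.80.ckpt", "model-epoch=02-train_acc=0.95.ckpt", "notes.txt"]
print(best_checkpoint(files, "train_acc"))  # model-epoch=02-train_acc=0.95.ckpt
```

In practice it would be called as best_checkpoint(os.listdir('./ckpt'), 'train_acc'). Use mode='min' for a loss-type metric.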

Inference

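After defining the base model class (CNN here), we run the following code. It reads the checkpoint from disk and copies its state dictionary (the trained weights) into a new, randomly initialized instance of the base model. The catch is that the keys differ. The Lightning module holds the base model as an attribute, so every key in the checkpoint carries that attribute name as a prefix (for example ‘model.fc2.bias’ instead of ‘fc2.bias’ if the attribute is called model), and the prefix has to be stripped before loading.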

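import os
import torch

root = './ckpt/'
# Take the first .ckpt file in the folder (with save_top_k=1 there is only one)
ckpt = sorted(f for f in os.listdir(root) if f.endswith('.ckpt'))[0]
# map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine
pre_trained_model = torch.load(os.path.join(root, ckpt), map_location='cpu')

base_model_new = CNN()  # same base-model class as in training, freshly initialized
print(f'Initial State: {base_model_new.state_dict()["fc2.bias"]}')

# Checkpoint keys carry the attribute name as a prefix, e.g. 'model.fc2.bias';
# strip it (assumes the base model is stored as self.model in the LightningModule)
prefix = 'model.'
my_model_kvpair = {key[len(prefix):]: value
                   for key, value in pre_trained_model['state_dict'].items()
                   if key.startswith(prefix)}  # skip entries that are not base-model weights
base_model_new.load_state_dict(my_model_kvpair)
print(f'After Loading: {base_model_new.state_dict()["fc2.bias"]}')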
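A different approach from the one above: if you control the training code, you can avoid the key renaming by also saving the base model’s own state_dict after training, taken from the Lightning module’s attribute. The sketch below shows this with plain PyTorch. CNN, Wrapper (a stand-in for the Lightning module, holding the base model as self.model) and the weights file name are all hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


class CNN(nn.Module):
    """Hypothetical stand-in for the real base model."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


class Wrapper(nn.Module):
    """Stand-in for the LightningModule; keeps the base model as self.model."""

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.model = base_model


trained = Wrapper(CNN())  # pretend this module has just finished training
weights_path = os.path.join(tempfile.mkdtemp(), "cnn_weights.pt")  # hypothetical file name

# Training side: save only the inner nn.Module, so the keys have no 'model.' prefix
torch.save(trained.model.state_dict(), weights_path)

# Inference side: only the CNN class is needed, no Lightning code and no key renaming
base_model_new = CNN()
base_model_new.load_state_dict(torch.load(weights_path, map_location="cpu"))
```

The trade-off is an extra file to manage next to the .ckpt. Lightning’s own load_from_checkpoint is another option, but it needs the training module’s class, which is exactly what this post tries to avoid.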

Hope it helps…


Author | MMG

Learning...