Changes in PyTorch Lightning 1.4

In PyTorch Lightning 1.4, the order in which the hooks are called has changed. Specifically, in PL 1.3.8 the order was:

on_train_epoch_start
training_step
training_step
training_step
training_step
training_epoch_end
on_epoch_end
on_validation_epoch_start
validation_step
validation_step
validation_step
validation_step
validation_epoch_end
on_epoch_end

In PL 1.4, it is:

on_train_epoch_start
training_step
training_step
training_step
training_step
on_validation_epoch_start
validation_step
validation_step
validation_step
validation_step
validation_epoch_end
on_epoch_end
training_epoch_end
on_epoch_end
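
If you want to verify this ordering yourself, one quick way (my own sketch, not something from the post) is to attach a Callback that prints each hook as it fires:

import pytorch_lightning as pl

class HookOrderCallback(pl.Callback):  # hypothetical name, purely for illustration
    def on_train_epoch_start(self, trainer, pl_module):
        print("on_train_epoch_start")

    def on_validation_epoch_start(self, trainer, pl_module):
        print("on_validation_epoch_start")

    def on_validation_epoch_end(self, trainer, pl_module):
        print("on_validation_epoch_end")

    def on_epoch_end(self, trainer, pl_module):
        print("on_epoch_end")

# usage, with a model and datamodule defined elsewhere:
# trainer = pl.Trainer(max_epochs=1, callbacks=[HookOrderCallback()])
# trainer.fit(model, datamodule)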

You might be wondering: so what is the issue?

Well, I used to take the average of the training loss and print it to the console. It could be done via:

def training_epoch_end(self, outputs):
    # outputs holds one dict per training batch, each containing that batch's loss
    self.avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean().item()

since outputs contained the loss from every training batch, and the hook was executed right after the last training_step. Now that training_epoch_end has moved to after the validation_epoch_end hook, this can no longer be done: by the time the average would be computed, validation has already finished.

My Fix?
Add these lines as a workaround:

def on_train_epoch_start(self):
    # reset the buffer that collects the per-batch training losses
    self.train_loss = torch.tensor([])

def on_validation_epoch_start(self):
    # in PL 1.4 validation starts before training_epoch_end, so the buffer
    # already holds the whole epoch's losses at this point
    self.avg_train_loss = self.train_loss.mean().item()
    self.train_loss = torch.tensor([])
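
One detail the snippet above doesn't show: the per-batch losses still have to land in self.train_loss during training. A minimal sketch of how that accumulation could look inside training_step (compute_loss is a stand-in for however you actually compute the loss):

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # stand-in for your actual loss computation
    # collect the detached batch loss so on_validation_epoch_start can average it
    self.train_loss = torch.cat([self.train_loss, loss.detach().cpu().unsqueeze(0)])
    return loss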

Moreover, whatever is now intended via training_epoch_end can actually be achieved via the on_epoch_end hook.
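
For instance, the printing itself could move into on_epoch_end, roughly like this (my own sketch; note that on_epoch_end also fires at the end of the validation epoch loop):

def on_epoch_end(self):
    # only print once the average from on_validation_epoch_start exists
    if getattr(self, "avg_train_loss", None) is not None:
        print(f"avg train loss: {self.avg_train_loss:.4f}")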


Author | MMG
