Changes in PyTorch Lightning 1.4
In PyTorch Lightning 1.4, the order in which hooks are called has changed. Specifically, in PL 1.3.8 the order within an epoch was:
on_train_epoch_start
training_step
training_step
training_step
training_step
training_epoch_end
on_epoch_end
on_validation_epoch_start
validation_step
validation_step
validation_step
validation_step
validation_epoch_end
on_epoch_end
In PL 1.4, it is:
on_train_epoch_start
training_step
training_step
training_step
training_step
on_validation_epoch_start
validation_step
validation_step
validation_step
validation_step
validation_epoch_end
on_epoch_end
training_epoch_end
on_epoch_end
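You can see the new ordering for yourself with a minimal module that prints from every hook. A quick sketch (the toy model, data, and Trainer settings here are placeholders of mine, not from the original setup):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class HookOrder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        print("training_step")
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        print("validation_step")
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_train_epoch_start(self):
        print("on_train_epoch_start")

    def on_validation_epoch_start(self):
        print("on_validation_epoch_start")

    def validation_epoch_end(self, outputs):
        print("validation_epoch_end")

    def training_epoch_end(self, outputs):
        print("training_epoch_end")

    def on_epoch_end(self):
        print("on_epoch_end")

# 16 samples with batch_size=4 gives the four training/validation steps shown above;
# num_sanity_val_steps=0 keeps the sanity-check validation pass out of the printout
ds = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
trainer = pl.Trainer(max_epochs=1, num_sanity_val_steps=0)
trainer.fit(HookOrder(), DataLoader(ds, batch_size=4), DataLoader(ds, batch_size=4))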
You might be wondering: so what is the issue?
Well, I used to take the average of the training loss and print it to the console. It could be done via:
def training_epoch_end(self, outputs):
    # outputs holds one dict per training batch, each containing that batch's loss
    self.avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean().item()
since outputs contained the loss from every training batch, and the hook ran right after the last training_step. Now that training_epoch_end has moved after the validation_epoch_end hook, the average isn't ready until validation has already finished, so this can't be done!
My Fix?
Add these hooks as a workaround, together with a line in training_step that appends each batch's loss to the buffer (a complete sketch follows below):

def on_train_epoch_start(self):
    # start a fresh buffer of per-batch losses for each training epoch
    self.train_loss = torch.tensor([])

def on_validation_epoch_start(self):
    # validation now starts before training_epoch_end, so average here
    self.avg_train_loss = self.train_loss.mean().item()
    self.train_loss = torch.tensor([])
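Putting it together, here is a minimal sketch of the workaround inside a complete LightningModule. The model, data, and the printing in validation_epoch_end are placeholders of mine; initializing self.train_loss in __init__ and guarding the average also keep the sanity-check validation pass (which runs before any training) from reading an empty buffer.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class AvgLossModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.train_loss = torch.tensor([])   # defined up front for the sanity check
        self.avg_train_loss = float("nan")

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # append this batch's loss to the buffer
        self.train_loss = torch.cat([self.train_loss, loss.detach().cpu().view(1)])
        return loss

    def on_train_epoch_start(self):
        self.train_loss = torch.tensor([])   # fresh buffer each training epoch

    def on_validation_epoch_start(self):
        if len(self.train_loss) > 0:         # empty during the sanity check
            self.avg_train_loss = self.train_loss.mean().item()
        self.train_loss = torch.tensor([])

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def validation_epoch_end(self, outputs):
        # the training average is already available here, before validation reports
        avg_val_loss = torch.stack(outputs).mean().item()
        print(f"avg train loss: {self.avg_train_loss:.4f} | avg val loss: {avg_val_loss:.4f}")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

ds = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
trainer = pl.Trainer(max_epochs=2)
trainer.fit(AvgLossModel(), DataLoader(ds, batch_size=16), DataLoader(ds, batch_size=16))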
Moreover, whatever is now intended via training_epoch_end can actually be achieved via the on_epoch_end hook, keeping in mind that in 1.4 this hook also fires after the validation epoch (see the sketch below).
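For example, the end-of-training-epoch logic can be guarded so it runs only once per epoch. A sketch, assuming the self.trainer.training stage property behaves in PL 1.4 the way I expect:

def on_epoch_end(self):
    # in PL 1.4 this hook fires twice per epoch: after validation_epoch_end and
    # again after training_epoch_end; act only in the training case
    # (self.trainer.training is an assumed 1.4 Trainer stage property)
    if self.trainer.training:
        print(f"epoch {self.current_epoch} done, avg train loss: {self.avg_train_loss:.4f}")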