diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py index 8fc498cff..ff41d6713 100644 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ b/src/axolotl/monkeypatch/trainer_grad_accum.py @@ -18,7 +18,6 @@ ORIGINAL_CONTEXT_CODE = """ loss = self.compute_loss(model, inputs) else: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) del inputs if ( @@ -37,12 +36,16 @@ ORIGINAL_CONTEXT_CODE = """ torch.mps.empty_cache() else: torch.cuda.empty_cache() + kwargs = {} + # For LOMO optimizers you need to explicitly use the learnign rate if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: kwargs["learning_rate"] = self._get_learning_rate() + if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training + if self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() @@ -73,12 +76,16 @@ PATCHED_CONTEXT_CODE = """ torch.mps.empty_cache() else: torch.cuda.empty_cache() + kwargs = {} + # For LOMO optimizers you need to explicitly use the learnign rate if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: kwargs["learning_rate"] = self._get_learning_rate() + if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training + if self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward()