fix patch

This commit is contained in:
Wing Lian
2025-01-13 14:02:27 -05:00
parent 5b5ba49c46
commit 4cc89f73f0

View File

@@ -18,7 +18,6 @@ ORIGINAL_CONTEXT_CODE = """
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
@@ -37,12 +36,16 @@ ORIGINAL_CONTEXT_CODE = """
torch.mps.empty_cache()
else:
torch.cuda.empty_cache()
kwargs = {}
# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
@@ -73,12 +76,16 @@ PATCHED_CONTEXT_CODE = """
torch.mps.empty_cache()
else:
torch.cuda.empty_cache()
kwargs = {}
# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()