fix patch
This commit is contained in:
@@ -18,7 +18,6 @@ ORIGINAL_CONTEXT_CODE = """
|
|||||||
loss = self.compute_loss(model, inputs)
|
loss = self.compute_loss(model, inputs)
|
||||||
else:
|
else:
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
|
||||||
|
|
||||||
del inputs
|
del inputs
|
||||||
if (
|
if (
|
||||||
@@ -37,12 +36,16 @@ ORIGINAL_CONTEXT_CODE = """
|
|||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
else:
|
else:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
# For LOMO optimizers you need to explicitly use the learnign rate
|
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||||
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
kwargs["learning_rate"] = self._get_learning_rate()
|
kwargs["learning_rate"] = self._get_learning_rate()
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
|
||||||
if self.use_apex:
|
if self.use_apex:
|
||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
@@ -73,12 +76,16 @@ PATCHED_CONTEXT_CODE = """
|
|||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
else:
|
else:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
# For LOMO optimizers you need to explicitly use the learnign rate
|
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||||
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
kwargs["learning_rate"] = self._get_learning_rate()
|
kwargs["learning_rate"] = self._get_learning_rate()
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
|
||||||
if self.use_apex:
|
if self.use_apex:
|
||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
|
|||||||
Reference in New Issue
Block a user