upgrade trl and accelerate (#3161)
* upgrade trl==0.23.0 * upgrade accelerate patch fix * add hints when using gradient_checkpointing with DPO * set gradient-checpointing properly
This commit is contained in:
@@ -435,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
training_args_kwargs["activation_offloading"] = True
|
||||
elif self.cfg.gradient_checkpointing:
|
||||
elif self.cfg.gradient_checkpointing is not None:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
@@ -1378,6 +1378,21 @@ class ComplexValidationMixin:
|
||||
|
||||
return self
|
||||
|
||||
def hint_gradient_checkpointing_dpo_lora_ddp(self):
|
||||
if (
|
||||
(self.gradient_checkpointing is True or self.gradient_checkpointing is None)
|
||||
and self.capabilities
|
||||
and self.capabilities.get("n_gpu", 1) > 1
|
||||
and self.adapter in ("lora", "qlora")
|
||||
and self.rl == RLType.DPO
|
||||
and not self.fsdp
|
||||
and not self.deepspeed
|
||||
):
|
||||
LOG.warning(
|
||||
"gradient_checkpointing with DPO + DDP + LoRA is not recommended."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class DistributedValidationMixin:
|
||||
"""validation for distributed training."""
|
||||
|
||||
Reference in New Issue
Block a user