diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index df4a08489..32168dc37 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -55,7 +55,7 @@ tf32: true gradient_checkpointing: true gradient_checkpointing_kwargs: - use_reentrant: true + use_reentrant: false early_stopping_patience: resume_from_checkpoint: local_rank: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 49f26ac0c..2fa86eced 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1679,6 +1679,30 @@ class AxolotlInputConfig( return data + @model_validator(mode="before") + @classmethod + def check_rl_config_gradient_checkpointing(cls, data): + # TODO: SalmanMohammadi + # Distributed RL with QLoRA + gradient checkpointing + # and use_reentrant = True is broken upstream in TRL + # pylint: disable=too-many-boolean-expressions + if ( + data.get("rl") + and data.get("gradient_checkpointing") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs").get("use_reentrant") + and data.get("load_in_4bit") + and data.get("adapter") == "qlora" + and data.get("capabilities") + and data.get("capabilities").get("n_gpu", 1) > 1 + ): + raise ValueError( + "The `use_reentrant: True` implementation of gradient checkpointing " + "is not supported for distributed RL training with QLoRA. Please set " + "`use_reentrant: False` in `gradient_checkpointing_kwargs`." + ) + return data + @model_validator(mode="before") @classmethod def check_kto_config(cls, data): @@ -1689,15 +1713,6 @@ class AxolotlInputConfig( if data.get("remove_unused_columns") is not False: raise ValueError("Set `remove_unused_columns: False` when using kto") - if data.get("gradient_checkpointing") and not ( - data.get("gradient_checkpointing_kwargs") - and isinstance(data.get("gradient_checkpointing_kwargs"), dict) - and data["gradient_checkpointing_kwargs"].get("use_reentrant") - ): - raise ValueError( - "Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled" - ) - return data