diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index df4a08489..fbcf80e9b 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -55,7 +55,7 @@ tf32: true gradient_checkpointing: true gradient_checkpointing_kwargs: - use_reentrant: true + use_reentrant: fal early_stopping_patience: resume_from_checkpoint: local_rank: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 674876e52..fbe63b757 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1685,6 +1685,7 @@ class AxolotlInputConfig( # TODO: SalmanMohammadi # Distributed RL with QLoRA + gradient checkpointing # and use_reentrant = True is broken upstream in TRL + # pylint: disable=too-many-boolean-expressions if ( data.get("rl") and data.get("gradient_checkpointing")