# Extracted from the patch "adding error" (Salman Mohammadi, 2025-03-18) for
# src/axolotl/utils/config/models/input/v0_4_1/__init__.py. The patch adds this
# validator to `AxolotlInputConfig`, directly before `check_kto_config`.
@model_validator(mode="before")
@classmethod
def check_rl_config_gradient_checkpointing(cls, data):
    """Reject reentrant gradient checkpointing for multi-GPU 4-bit RL runs.

    The check fires only when all of the following hold in the raw config:
    `rl` is set, `gradient_checkpointing` is enabled, `load_in_4bit` is set,
    `gradient_checkpointing_kwargs["use_reentrant"]` is truthy, and
    `capabilities["n_gpu"]` is greater than 1.

    Args:
        data: The raw (pre-validation) config mapping.

    Returns:
        `data`, unchanged, when the unsupported combination is not present.

    Raises:
        ValueError: If the unsupported combination described above is configured.
    """
    # TODO(SalmanMohammadi): Distributed RL with QLoRA + gradient checkpointing
    # and use_reentrant = True is broken upstream in TRL
    if not (
        data.get("rl")
        and data.get("gradient_checkpointing")
        and data.get("load_in_4bit")
    ):
        return data

    checkpointing_kwargs = data.get("gradient_checkpointing_kwargs")
    # A missing/empty kwargs mapping or a falsy flag means reentrant
    # checkpointing was not explicitly requested.
    if not checkpointing_kwargs or not checkpointing_kwargs.get("use_reentrant"):
        return data

    capabilities = data.get("capabilities")
    if not capabilities:
        return data

    # `n_gpu` may be present but unset (None); treat that like the default of 1
    # instead of failing on a `None > 1` comparison.
    n_gpu = capabilities.get("n_gpu") or 1
    if n_gpu > 1:
        raise ValueError(
            "The `use_reentrant: True` implementation of gradient checkpointing "
            "is not supported for distributed RL training with QLoRA. Please set "
            "`use_reentrant: False` in `gradient_checkpointing_kwargs`."
        )
    return data