adding error

This commit is contained in:
Salman Mohammadi
2025-03-18 11:20:34 +00:00
parent a319ac7d3e
commit ec4ead6e3e

View File

@@ -1,4 +1,5 @@
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging
@@ -1678,6 +1679,29 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_rl_config_gradient_checkpointing(cls, data):
# TODO: SalmanMohammadi
# Distributed RL with QLoRA + gradient checkpointing
# and use_reentrant = True is broken upstream in TRL
if (
data.get("rl")
and data.get("gradient_checkpointing")
and data.get("gradient_checkpointing_kwargs")
and "use_reentrant" in data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
and data.get("load_in_4bit")
and data.get("capabilities")
and data.get("capabilities").get("n_gpu", 1) > 1
):
raise ValueError(
"The `use_reentrant: True` implementation of gradient checkpointing "
"is not supported for distributed RL training with QLoRA. Please set "
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
)
return data
@model_validator(mode="before")
@classmethod
def check_kto_config(cls, data):