adding error
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Module with Pydantic models for configuration."""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
@@ -1678,6 +1679,29 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_rl_config_gradient_checkpointing(cls, data):
|
||||
# TODO: SalmanMohammadi
|
||||
# Distributed RL with QLoRA + gradient checkpointing
|
||||
# and use_reentrant = True is broken upstream in TRL
|
||||
if (
|
||||
data.get("rl")
|
||||
and data.get("gradient_checkpointing")
|
||||
and data.get("gradient_checkpointing_kwargs")
|
||||
and "use_reentrant" in data.get("gradient_checkpointing_kwargs")
|
||||
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
|
||||
and data.get("load_in_4bit")
|
||||
and data.get("capabilities")
|
||||
and data.get("capabilities").get("n_gpu", 1) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"The `use_reentrant: True` implementation of gradient checkpointing "
|
||||
"is not supported for distributed RL training with QLoRA. Please set "
|
||||
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_kto_config(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user