Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c36a6fea6 | ||
|
|
64aca3c23c | ||
|
|
22abfd6170 | ||
|
|
0658c458b7 | ||
|
|
690908cf2f | ||
|
|
b9378e9b39 | ||
|
|
57b0ad1467 | ||
|
|
ec4ead6e3e | ||
|
|
a319ac7d3e | ||
|
|
09d3f2cffa |
@@ -55,7 +55,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
|
||||
@@ -1679,6 +1679,30 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_rl_config_gradient_checkpointing(cls, data):
|
||||
# TODO: SalmanMohammadi
|
||||
# Distributed RL with QLoRA + gradient checkpointing
|
||||
# and use_reentrant = True is broken upstream in TRL
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (
|
||||
data.get("rl")
|
||||
and data.get("gradient_checkpointing")
|
||||
and data.get("gradient_checkpointing_kwargs")
|
||||
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
|
||||
and data.get("load_in_4bit")
|
||||
and data.get("adapter") == "qlora"
|
||||
and data.get("capabilities")
|
||||
and data.get("capabilities").get("n_gpu", 1) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"The `use_reentrant: True` implementation of gradient checkpointing "
|
||||
"is not supported for distributed RL training with QLoRA. Please set "
|
||||
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_kto_config(cls, data):
|
||||
@@ -1689,15 +1713,6 @@ class AxolotlInputConfig(
|
||||
if data.get("remove_unused_columns") is not False:
|
||||
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
||||
|
||||
if data.get("gradient_checkpointing") and not (
|
||||
data.get("gradient_checkpointing_kwargs")
|
||||
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
|
||||
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
|
||||
):
|
||||
raise ValueError(
|
||||
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user