Fixing KTO+QLoRA+multi-GPU (#2420)
* WIP * removing artifacts * adding error * adding adapter check * linting * simplifying check * linting v2 * config fix -___-
This commit is contained in:
@@ -55,7 +55,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true
|
use_reentrant: false
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
@@ -1679,6 +1679,30 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_rl_config_gradient_checkpointing(cls, data):
    """Reject reentrant gradient checkpointing for multi-GPU RL with QLoRA.

    Runs as a ``mode="before"`` validator, so ``data`` is the raw input
    mapping and nested values have not been type-checked yet.

    Args:
        data: Raw config mapping being validated.

    Returns:
        ``data`` unchanged when the configuration is acceptable.

    Raises:
        ValueError: If ``rl``, ``gradient_checkpointing``, ``load_in_4bit``
            and ``adapter == "qlora"`` are all set, ``use_reentrant`` is
            truthy in ``gradient_checkpointing_kwargs``, and
            ``capabilities`` reports more than one GPU.
    """
    # TODO: SalmanMohammadi
    # Distributed RL with QLoRA + gradient checkpointing
    # and use_reentrant = True is broken upstream in TRL
    gc_kwargs = data.get("gradient_checkpointing_kwargs")
    # Raw input may hold a non-dict here; only a dict can carry
    # `use_reentrant`, so anything else is treated as "not set" instead of
    # crashing with AttributeError on `.get`.
    uses_reentrant = isinstance(gc_kwargs, dict) and bool(
        gc_kwargs.get("use_reentrant")
    )

    capabilities = data.get("capabilities")
    n_gpu = capabilities.get("n_gpu", 1) if capabilities else 1
    # An explicit `n_gpu: None` is treated as a single GPU rather than
    # raising TypeError on the comparison below.
    is_multi_gpu = (n_gpu or 1) > 1

    is_rl_qlora = bool(
        data.get("rl")
        and data.get("load_in_4bit")
        and data.get("adapter") == "qlora"
    )

    if (
        is_rl_qlora
        and data.get("gradient_checkpointing")
        and uses_reentrant
        and is_multi_gpu
    ):
        raise ValueError(
            "The `use_reentrant: True` implementation of gradient checkpointing "
            "is not supported for distributed RL training with QLoRA. Please set "
            "`use_reentrant: False` in `gradient_checkpointing_kwargs`."
        )
    return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_kto_config(cls, data):
|
def check_kto_config(cls, data):
|
||||||
@@ -1689,15 +1713,6 @@ class AxolotlInputConfig(
|
|||||||
if data.get("remove_unused_columns") is not False:
|
if data.get("remove_unused_columns") is not False:
|
||||||
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
||||||
|
|
||||||
if data.get("gradient_checkpointing") and not (
|
|
||||||
data.get("gradient_checkpointing_kwargs")
|
|
||||||
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
|
|
||||||
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user