Add additional validation for FSDP2; bump dependency versions

This commit is contained in:
Wing Lian
2025-04-06 15:18:56 -04:00
parent 1a5d445413
commit c7f1c191a3
3 changed files with 16 additions and 3 deletions

View File

@@ -12,12 +12,12 @@ liger-kernel==0.5.5
packaging==23.2
peft==0.15.0
transformers @ git+https://github.com/huggingface/transformers.git@a4e55fcff8d980eab0c9cf9e51ca13460437e1c7
transformers==4.51.0
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
deepspeed==0.15.4
trl==0.16.0
deepspeed>=0.15.4
trl==0.16.1
optimum==1.16.2
hf_transfer

View File

@@ -8,6 +8,7 @@ import transformers
def patch_flex_wrapper():
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"

View File

@@ -955,6 +955,18 @@ class AxolotlInputConfig(
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
if (
data.get("fsdp")
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and str(data["fsdp_config"].get("fsdp_version")) == "2"
):
if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]:
# CUDA ops errors with bnb 8bit optimizer + FSDP2
raise ValueError(
f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead"
)
return data
@model_validator(mode="before")