Add validation for FSDP2 and bump dependency versions
This commit is contained in:
@@ -12,12 +12,12 @@ liger-kernel==0.5.5
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@a4e55fcff8d980eab0c9cf9e51ca13460437e1c7
|
||||
transformers==4.51.0
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
deepspeed==0.15.4
|
||||
trl==0.16.0
|
||||
deepspeed>=0.15.4
|
||||
trl==0.16.1
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
|
||||
@@ -8,6 +8,7 @@ import transformers
|
||||
|
||||
|
||||
def patch_flex_wrapper():
|
||||
# TODO remove this patch when transformers#37285 is merged and in a release
|
||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
|
||||
|
||||
|
||||
@@ -955,6 +955,18 @@ class AxolotlInputConfig(
|
||||
raise ValueError(
|
||||
f"FSDP Offload not compatible with {data.get('optimizer')}"
|
||||
)
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and "8bit" in data.get("optimizer", "")
|
||||
and data.get("fsdp_config")
|
||||
and str(data["fsdp_config"].get("fsdp_version")) == "2"
|
||||
):
|
||||
if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]:
|
||||
# bnb 8-bit optimizers raise CUDA op errors when used with FSDP2
|
||||
raise ValueError(
|
||||
f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
Reference in New Issue
Block a user