From c7f1c191a31bd151b04d1d741ff2f5bdc4a08c73 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 6 Apr 2025 15:18:56 -0400 Subject: [PATCH] additional validation for fsdp2, bump dep versions --- requirements.txt | 6 +++--- src/axolotl/monkeypatch/attention/flex_attn.py | 1 + src/axolotl/utils/schemas/config.py | 12 ++++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7be0f7e01..d82489203 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,12 +12,12 @@ liger-kernel==0.5.5 packaging==23.2 peft==0.15.0 -transformers @ git+https://github.com/huggingface/transformers.git@a4e55fcff8d980eab0c9cf9e51ca13460437e1c7 +transformers==4.51.0 tokenizers>=0.21.1 accelerate==1.6.0 datasets==3.5.0 -deepspeed==0.15.4 -trl==0.16.0 +deepspeed>=0.15.4 +trl==0.16.1 optimum==1.16.2 hf_transfer diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 4098c8c1c..2ca5b09a6 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -8,6 +8,7 @@ import transformers def patch_flex_wrapper(): + # TODO remove this patch when transformers#37285 is merged and in a release is_torch_2_6 = torch.__version__.startswith("2.6") is_transformers_below_4_51 = transformers.__version__ < "4.51.0" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 6bc3519b6..3ceae4273 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -955,6 +955,18 @@ class AxolotlInputConfig( raise ValueError( f"FSDP Offload not compatible with {data.get('optimizer')}" ) + if ( + data.get("fsdp") + and "8bit" in data.get("optimizer", "") + and data.get("fsdp_config") + and str(data["fsdp_config"].get("fsdp_version")) == "2" + ): + if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: + # CUDA ops errors with bnb 8bit optimizer + FSDP2 + raise ValueError( + f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead" + ) + return data @model_validator(mode="before")