From 811224d7b7847811ddb4803bc11db2355cbd3df1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 6 Dec 2024 11:34:06 -0500 Subject: [PATCH] =?UTF-8?q?broken=20=F0=9F=A6=A5=20with=20latest=20transfo?= =?UTF-8?q?rmers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/axolotl/cli/__init__.py | 2 +- .../utils/config/models/input/v0_4_1/__init__.py | 14 ++++++++++++++ tests/e2e/patched/test_unsloth_qlora.py | 1 + 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index e8ef86285..d07b10ce3 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): "compute_capability": gpu_version, }, env_capabilities={ - "torch_version": str(torch.__version__).split("+", maxsplit=1)[0] + "torch_version": str(torch.__version__).split("+", maxsplit=1)[0], }, ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 24ea62c77..5a4d27118 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1432,6 +1432,20 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def notify_qlora_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + LOG.info( + "Unsloth may not be well supported with the latest version of Transformers, " + "resulting in loss that is incorrect." + ) + return data + @model_validator(mode="before") @classmethod def check_torch_compile_deepspeed(cls, data): diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 8e0d03380..231cbd28e 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true" # pylint: disable=duplicate-code +@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers") class TestUnslothQLoRA: """ Test class for Unsloth QLoRA Llama models