diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index e8ef86285..d07b10ce3 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): "compute_capability": gpu_version, }, env_capabilities={ - "torch_version": str(torch.__version__).split("+", maxsplit=1)[0] + "torch_version": str(torch.__version__).split("+", maxsplit=1)[0], }, ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 24ea62c77..5a4d27118 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1432,6 +1432,20 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def notify_qlora_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + LOG.info( + "Unsloth may not be well supported with the latest version of Transformers, " + "resulting in loss that is incorrect." + ) + return data + @model_validator(mode="before") @classmethod def check_torch_compile_deepspeed(cls, data): diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 8e0d03380..231cbd28e 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true" # pylint: disable=duplicate-code +@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers") class TestUnslothQLoRA: """ Test class for Unsloth QLoRA Llama models