broken 🦥 with latest transformers

This commit is contained in:
Wing Lian
2024-12-06 11:34:06 -05:00
parent 84a14fc604
commit 811224d7b7
3 changed files with 16 additions and 1 deletions

View File

@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
env_capabilities={ env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0] "torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
}, },
) )

View File

@@ -1432,6 +1432,20 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def notify_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
LOG.info(
"Unsloth may not be well supported with the latest version of Transformers, "
"resulting in loss that is incorrect."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_torch_compile_deepspeed(cls, data): def check_torch_compile_deepspeed(cls, data):

View File

@@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers")
class TestUnslothQLoRA: class TestUnslothQLoRA:
""" """
Test class for Unsloth QLoRA Llama models Test class for Unsloth QLoRA Llama models