Unsloth 🦥 is broken with the latest transformers: notify users enabling Unsloth LoRA optimizations and skip the Unsloth QLoRA tests

This commit is contained in:
Wing Lian
2024-12-06 11:34:06 -05:00
parent 84a14fc604
commit 811224d7b7
3 changed files with 16 additions and 1 deletions

View File

@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
},
)

View File

@@ -1432,6 +1432,20 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def notify_qlora_unsloth(cls, data):
    """Notify the user when any Unsloth LoRA optimization flag is set.

    Recent ``transformers`` releases are not well supported by the Unsloth
    patched LoRA kernels and can silently produce an incorrect loss, so
    surface a message before validation proceeds.

    Args:
        data: Raw (pre-validation) config mapping.

    Returns:
        The input ``data``, unmodified.
    """
    unsloth_flags = ("unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o")
    if any(data.get(flag) for flag in unsloth_flags):
        # WARNING (not INFO): a potentially incorrect training loss is a
        # correctness hazard, not an informational detail.
        LOG.warning(
            "Unsloth may not be well supported with the latest version of Transformers, "
            "resulting in loss that is incorrect."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):

View File

@@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers")
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models