broken 🦥 with latest transformers
This commit is contained in:
@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1432,6 +1432,20 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def notify_qlora_unsloth(cls, data):
|
||||
if (
|
||||
data.get("unsloth_lora_mlp")
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
LOG.info(
|
||||
"Unsloth may not be well supported with the latest version of Transformers, "
|
||||
"resulting in loss that is incorrect."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
|
||||
@@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers")
|
||||
class TestUnslothQLoRA:
|
||||
"""
|
||||
Test class for Unsloth QLoRA Llama models
|
||||
|
||||
Reference in New Issue
Block a user