broken 🦥 with latest transformers
This commit is contained in:
@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
env_capabilities={
|
env_capabilities={
|
||||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1432,6 +1432,20 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def notify_qlora_unsloth(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"Unsloth may not be well supported with the latest version of Transformers, "
|
||||||
|
"resulting in loss that is incorrect."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_deepspeed(cls, data):
|
def check_torch_compile_deepspeed(cls, data):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true"
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers")
|
||||||
class TestUnslothQLoRA:
|
class TestUnslothQLoRA:
|
||||||
"""
|
"""
|
||||||
Test class for Unsloth QLoRA Llama models
|
Test class for Unsloth QLoRA Llama models
|
||||||
|
|||||||
Reference in New Issue
Block a user