fix inference when no chat_template is set, fix unsloth dora check (#2092)
* fix inference when no chat_template is set, fix unsloth dora check * remove old unsloth version check * update docs on installing unsloth
This commit is contained in:
@@ -11,12 +11,10 @@ standard industry baselines.
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
|
The following will install the correct unsloth and extras from source.
|
||||||
to date libraries.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
|
python scripts/unsloth_install.py | sh
|
||||||
pip install --no-deps --force-reinstall xformers==0.0.26.post1
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using unsloth w Axolotl
|
### Using unsloth w Axolotl
|
||||||
|
|||||||
33
scripts/unsloth_install.py
Normal file
33
scripts/unsloth_install.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# noqa
|
||||||
|
# pylint: skip-file
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Install torch via `pip install torch`")
|
||||||
|
from packaging.version import Version as V
|
||||||
|
|
||||||
|
v = V(torch.__version__)
|
||||||
|
cuda = str(torch.version.cuda)
|
||||||
|
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||||
|
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||||
|
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||||
|
if v <= V("2.1.0"):
|
||||||
|
raise RuntimeError(f"Torch = {v} too old!")
|
||||||
|
elif v <= V("2.1.1"):
|
||||||
|
x = "cu{}{}-torch211"
|
||||||
|
elif v <= V("2.1.2"):
|
||||||
|
x = "cu{}{}-torch212"
|
||||||
|
elif v < V("2.3.0"):
|
||||||
|
x = "cu{}{}-torch220"
|
||||||
|
elif v < V("2.4.0"):
|
||||||
|
x = "cu{}{}-torch230"
|
||||||
|
elif v < V("2.5.0"):
|
||||||
|
x = "cu{}{}-torch240"
|
||||||
|
elif v < V("2.6.0"):
|
||||||
|
x = "cu{}{}-torch250"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Torch = {v} too new!")
|
||||||
|
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||||
|
print(
|
||||||
|
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
|
||||||
|
)
|
||||||
@@ -30,7 +30,10 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.chat_templates import get_chat_template
|
from axolotl.utils.chat_templates import (
|
||||||
|
get_chat_template,
|
||||||
|
get_chat_template_from_config,
|
||||||
|
)
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
@@ -199,6 +202,10 @@ def do_inference(
|
|||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = get_chat_template(cfg.chat_template)
|
chat_template_str = get_chat_template(cfg.chat_template)
|
||||||
|
elif cfg.datasets[0].type == "chat_template":
|
||||||
|
chat_template_str = get_chat_template_from_config(
|
||||||
|
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
|||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
mlp_not_dora = all(
|
mlp_not_dora = all(
|
||||||
getattr(module, "lora_magnitude_vector", None) is None
|
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
qkv_not_dora = all(
|
qkv_not_dora = all(
|
||||||
getattr(module, "lora_magnitude_vector", None) is None
|
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
o_not_dora = all(
|
o_not_dora = all(
|
||||||
getattr(module, "lora_magnitude_vector", None) is None
|
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ Module for pydantic models for configuration
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from importlib.metadata import version
|
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@@ -1425,21 +1424,6 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_unsloth_xformers_version(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("unsloth_lora_mlp")
|
|
||||||
or data.get("unsloth_lora_qkv")
|
|
||||||
or data.get("unsloth_lora_o")
|
|
||||||
):
|
|
||||||
xformers_version = version("xformers")
|
|
||||||
if xformers_version == "0.0.27":
|
|
||||||
raise ValueError(
|
|
||||||
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_deepspeed(cls, data):
|
def check_torch_compile_deepspeed(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user