fix inference when no chat_template is set, fix unsloth dora check (#2092)
* fix inference when no chat_template is set, fix unsloth dora check * remove old unsloth version check * update docs on installing unsloth
This commit is contained in:
@@ -11,12 +11,10 @@ standard industry baselines.
|
||||
|
||||
### Installation
|
||||
|
||||
The following will install unsloth from source and downgrade xformers, as unsloth is incompatible with the most
|
||||
up-to-date libraries.
|
||||
The following command installs unsloth from source, together with the extras that match your installed torch and CUDA versions.
|
||||
|
||||
```bash
|
||||
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
|
||||
pip install --no-deps --force-reinstall xformers==0.0.26.post1
|
||||
python scripts/unsloth_install.py | sh
|
||||
```
|
||||
|
||||
### Using unsloth with Axolotl
|
||||
|
||||
33
scripts/unsloth_install.py
Normal file
33
scripts/unsloth_install.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# noqa
# pylint: skip-file
"""Print the shell command that installs the unsloth extra for this machine.

The script looks at three things: the installed torch version, the CUDA
version torch reports, and the compute capability of the current GPU. From
those it builds the name of an unsloth "extra" and prints a ``pip install``
command line. Pipe the output into a shell to execute it::

    python scripts/unsloth_install.py | sh

Raises:
    ImportError: torch is not installed.
    RuntimeError: the CUDA version or the torch version is not one of the
        combinations handled below.
"""
try:
    import torch
except ImportError:
    raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V

installed = V(torch.__version__)
# Compare on the public version only. A local segment such as "+cu121" sorts
# *after* the same public release (PEP 440), which would push e.g.
# "2.1.1+cu121" past the `<= 2.1.1` branch and select the wrong extra.
v = V(installed.public)

cuda = str(torch.version.cuda)
# Check the CUDA version first: when torch reports no CUDA version the string
# is "None", and this message is clearer than whatever the device query below
# would raise in that situation.
if cuda not in ("12.1", "11.8", "12.4"):
    raise RuntimeError(f"CUDA = {cuda} not supported!")

# Compute capability major version 8 or higher selects the "-ampere" extras.
is_ampere = torch.cuda.get_device_capability()[0] >= 8

# Map the torch release onto the extra-name template; the two placeholders
# are filled with the CUDA digits and the optional "-ampere" marker.
if v <= V("2.1.0"):
    raise RuntimeError(f"Torch = {installed} too old!")
elif v <= V("2.1.1"):
    x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
    x = "cu{}{}-torch212"
elif v < V("2.3.0"):
    x = "cu{}{}-torch220"
elif v < V("2.4.0"):
    x = "cu{}{}-torch230"
elif v < V("2.5.0"):
    x = "cu{}{}-torch240"
elif v < V("2.6.0"):
    x = "cu{}{}-torch250"
else:
    raise RuntimeError(f"Torch = {installed} too new!")

x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
    f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
)
|
||||
@@ -30,7 +30,10 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
@@ -199,6 +202,10 @@ def do_inference(
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||
for module in layer_modules
|
||||
)
|
||||
mlp_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
qkv_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
o_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ Module for pydantic models for configuration
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import (
|
||||
@@ -1425,21 +1424,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
    """Reject xformers 0.0.27 when any unsloth LoRA option is enabled.

    Args:
        data: the raw config mapping, read with ``.get``.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: one of ``unsloth_lora_mlp`` / ``unsloth_lora_qkv`` /
            ``unsloth_lora_o`` is truthy and the installed xformers
            distribution reports version ``0.0.27``.
    """
    unsloth_options = ("unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o")
    # Nothing to verify unless at least one unsloth patch is requested.
    if not any(data.get(option) for option in unsloth_options):
        return data

    if version("xformers") == "0.0.27":
        raise ValueError(
            "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
        )
    return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user