fix inference when no chat_template is set, fix unsloth dora check (#2092)

* fix inference when no chat_template is set, fix unsloth dora check

* remove old unsloth version check

* update docs on installing unsloth
This commit is contained in:
Wing Lian
2024-11-20 14:07:54 -05:00
committed by GitHub
parent 68a26f1005
commit 2e99bb303e
5 changed files with 46 additions and 24 deletions

View File

@@ -30,7 +30,10 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -199,6 +202,10 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
for module in layer_modules
)
mlp_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
qkv_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
o_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

View File

@@ -7,7 +7,6 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import (
@@ -1425,21 +1424,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):