fix inference when no chat_template is set, fix unsloth dora check (#2092)
* fix inference when no chat_template is set, fix unsloth dora check * remove old unsloth version check * update docs on installing unsloth
This commit is contained in:
@@ -30,7 +30,10 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
@@ -199,6 +202,10 @@ def do_inference(
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||
for module in layer_modules
|
||||
)
|
||||
mlp_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
qkv_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
o_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ Module for pydantic models for configuration
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import (
|
||||
@@ -1425,21 +1424,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_unsloth_xformers_version(cls, data):
|
||||
if (
|
||||
data.get("unsloth_lora_mlp")
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
xformers_version = version("xformers")
|
||||
if xformers_version == "0.0.27":
|
||||
raise ValueError(
|
||||
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user