Compare commits

..

10 Commits

Author SHA1 Message Date
Wing Lian
31723ac523 fix whitespace for patch check 2024-12-06 16:43:44 -05:00
Wing Lian
2e9e423dfd detab the code to check 2024-12-06 16:42:29 -05:00
Wing Lian
cbe61186dc patches for llama ga 2024-12-06 16:40:24 -05:00
Wing Lian
2a83580bdc also bump accelerate 2024-12-06 15:24:57 -05:00
Wing Lian
825f66b9fd update HF HUB env var and fix reward trainer log since it doesn't directly override log 2024-12-06 14:52:59 -05:00
Wing Lian
3b44989205 skip parent, call grandparent - yeah, super janky 2024-12-06 12:19:14 -05:00
Wing Lian
811224d7b7 broken 🦥 with latest transformers 2024-12-06 11:34:06 -05:00
Wing Lian
84a14fc604 fix trl trainer.log interfaces 2024-12-06 10:35:29 -05:00
NanoCode012
86cf62ca46 fix: update trainer.log signature 2024-12-06 10:27:18 -05:00
Wing Lian
fc54e10455 bump transformers and trl 2024-12-06 10:27:12 -05:00
13 changed files with 64 additions and 228 deletions

View File

@@ -2,6 +2,6 @@
set -e set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers>=4.46.3 transformers==4.47.0
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.45.0 bitsandbytes==0.45.0
accelerate==1.2.0 accelerate==1.2.0
@@ -31,7 +31,7 @@ art
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq==0.2.7.post3 autoawq==0.2.7.post2
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.4.2 liger-kernel==0.4.2

View File

@@ -5,7 +5,6 @@ from typing import Optional
import click import click
import axolotl
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -17,7 +16,6 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group() @click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli(): def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -974,13 +973,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs, start_time)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1172,13 +1165,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
@@ -1196,13 +1185,9 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
@@ -1247,13 +1232,9 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
@@ -1271,13 +1252,9 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
@@ -1289,12 +1266,9 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method # TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call logs, start_time
logs, start_time )
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):

View File

@@ -1,80 +0,0 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -3,13 +3,14 @@ fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128 see https://github.com/huggingface/transformers/pull/35128
""" """
import inspect import inspect
import logging
from transformers import LlamaForCausalLM, Trainer from accelerate.logging import get_logger
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """ ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager(): with self.compute_loss_context_manager():
@@ -66,7 +67,7 @@ PATCHED_LLAMA_FCLM_CODE = """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention # remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None) num_items_in_batch = kwargs.pop("num_items_in_batch")
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model( outputs = self.model(
@@ -110,17 +111,12 @@ def patch_training_step_for_ga():
monkeypatch for fixing the training loop for gradient accumulation monkeypatch for fixing the training loop for gradient accumulation
""" """
try: training_step = get_training_step_code()
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step) training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step: assert (
return ORIGINAL_CONTEXT_CODE in training_step
# assert ( ), "Original training_step code not found"
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE) training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace( training_step = training_step.replace(
@@ -144,7 +140,7 @@ def patch_training_step_for_ga():
globals(), globals(),
) )
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102 exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step") LOG.info("patching training_step", main_process_only=True)
Trainer.training_step = ( # pylint: disable=protected-access Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821 _fixed_training_step # pylint: disable=undefined-variable # noqa: F821
) )
@@ -168,15 +164,10 @@ def patch_forward_for_ga():
monkeypatch for fixing the training loop for gradient accumulation monkeypatch for fixing the training loop for gradient accumulation
""" """
try: forward = get_model_forward_code()
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward) forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward: assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE) forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace( forward = forward.replace(
@@ -200,7 +191,7 @@ def patch_forward_for_ga():
globals(), globals(),
) )
exec(forward, globals()) # pylint: disable=exec-used # nosec B102 exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward") LOG.info("patching forward", main_process_only=True)
LlamaForCausalLM.forward = ( # pylint: disable=protected-access LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821 _fixed_forward # pylint: disable=undefined-variable # noqa: F821
) )

View File

@@ -9,7 +9,10 @@ import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
@@ -52,6 +55,11 @@ def original_apply_o(self, hidden_states):
return attn_output return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def get_self_attn_code() -> str: def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward) forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward return forward
@@ -94,22 +102,12 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
def detab_code(code: str) -> Tuple[str, str]: def detab_code(code: str) -> Tuple[str, str]:
try: spaces = re.match(r"([\s\t]{1,})", code).group(0)
spaces = re.match(r"([\s\t]{1,})", code).group(0) code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora(): def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code() self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward self_attn_forward
@@ -141,7 +139,6 @@ def patch_self_attn_lora():
globals(), globals(),
) )
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True) LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = ( LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821

View File

@@ -153,7 +153,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
( (
hasattr(model_config, "model_type") hasattr(model_config, "model_type")
and model_config.model_type in ["llama", "mllama_text_model"] and model_config.model_type == ["llama", "mllama_text_model"]
) )
or cfg.is_llama_derived_model or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower() or "llama" in cfg.base_model.lower()

View File

@@ -1432,6 +1432,20 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def notify_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
LOG.info(
"Unsloth may not be well supported with the latest version of Transformers, "
"resulting in loss that is incorrect."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_torch_compile_deepspeed(cls, data): def check_torch_compile_deepspeed(cls, data):

View File

@@ -380,13 +380,6 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
if self.cfg.gradient_checkpointing == "unsloth": if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
@@ -413,14 +406,10 @@ class ModelLoader:
and self.cfg.flash_attention and self.cfg.flash_attention
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
if "auto_map" in self.model_config: has_remote_code = (
try: "auto_map" in self.model_config
auto_map_config = self.model_config["auto_map"] and "AutoModelForCausalLM" in self.model_config["auto_map"]
except TypeError: )
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False: if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code

View File

@@ -2,9 +2,7 @@
shared pytest fixtures shared pytest fixtures
""" """
import functools import functools
import importlib
import shutil import shutil
import sys
import tempfile import tempfile
import time import time
@@ -115,40 +113,3 @@ def temp_dir():
yield _temp_dir yield _temp_dir
# Clean up the directory after the test # Clean up the directory after the test
shutil.rmtree(_temp_dir) shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)

View File

@@ -20,6 +20,7 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers")
class TestUnslothQLoRA: class TestUnslothQLoRA:
""" """
Test class for Unsloth QLoRA Llama models Test class for Unsloth QLoRA Llama models
@@ -36,9 +37,6 @@ class TestUnslothQLoRA:
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": sample_packing, "sample_packing": sample_packing,
"flash_attention": True, "flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
@@ -85,9 +83,6 @@ class TestUnslothQLoRA:
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False, "sample_packing": False,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
@@ -139,9 +134,6 @@ class TestUnslothQLoRA:
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False, "sample_packing": False,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",