Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
39ab9626f1 add transformers module to cleanup 2024-12-08 14:52:54 -05:00
Wing Lian
26bd81cec0 re-enable tests w change in patching 2024-12-08 14:52:09 -05:00
Wing Lian
1302e31049 Transformers version flexibility and FSDP optimizer patch (#2155)
* allow flexibility in transformers version for FSDP

* more flexibility with dev versions of 4.47.0.dev0

* add patch for fsdp

* fix typo

* correct fn name

* stray character

* fix patch

* reset Trainer too

* also reset Trainer.training_step

* allow tests/patched to run more than one process on e2e runner

* skip tests/patched in e2e for now since it's run in regular pytest
2024-12-08 14:50:40 -05:00
Wing Lian
be5f554a62 bump autoawq to 0.2.7.post3 (#2150) 2024-12-07 22:24:09 -05:00
Wing Lian
22319182ab fix for auto_map check when using remote code and multipack for models like deepseek (#2151) [skip ci] 2024-12-07 22:23:52 -05:00
Wing Lian
440aab8a6f add --version support to axolotl cli (#2152) [skip ci] 2024-12-07 22:23:33 -05:00
Wing Lian
5bef19064b [tests] reset known modules that are patched on each test function end (#2147)
* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manually workaround for llama fa2
2024-12-07 17:24:46 -05:00
Wing Lian
743ba62bd5 Transformers 4.47.0 (#2138)
* bump transformers and trl

* fix: update trainer.log signature

* fix trl trainer.log interfaces

* broken 🦥 with latest transformers

* skip parent, call grandparent - yeah, super janky

* update HF HUB env var and fix reward trainer log since it doesn't directly override log

* also bump accelerate

* patches for llama ga

* detab the code to check

* fix whitespace for patch check

* play nicely with CI tests since we patch everytime

* fix pop default in case it doesn't exist

* more tweaks to make patches nicer in CI

* fix detab for when there are possibly multiple patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-07 05:03:01 -05:00
Chirag Jain
f9a7748bd8 Fix llama type model check (#2142) [skip ci] 2024-12-07 05:02:32 -05:00
13 changed files with 228 additions and 64 deletions

View File

@@ -2,6 +2,6 @@
set -e set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.47.0 transformers>=4.46.3
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.45.0 bitsandbytes==0.45.0
accelerate==1.2.0 accelerate==1.2.0
@@ -31,7 +31,7 @@ art
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq==0.2.7.post2 autoawq==0.2.7.post3
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.4.2 liger-kernel==0.4.2

View File

@@ -5,6 +5,7 @@ from typing import Optional
import click import click
import axolotl
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group() @click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli(): def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -973,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs, start_time)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1165,9 +1172,13 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
@@ -1185,9 +1196,13 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
@@ -1232,9 +1247,13 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
@@ -1252,9 +1271,13 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
@@ -1266,9 +1289,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method # TODO remove once trl supports the updated to the Trainer.log method
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
logs, start_time return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
) logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):

View File

@@ -0,0 +1,80 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -3,14 +3,13 @@ fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128 see https://github.com/huggingface/transformers/pull/35128
""" """
import inspect import inspect
import logging
from accelerate.logging import get_logger from transformers import LlamaForCausalLM, Trainer
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code from axolotl.monkeypatch.unsloth_ import detab_code
LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum") LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """ ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager(): with self.compute_loss_context_manager():
@@ -67,7 +66,7 @@ PATCHED_LLAMA_FCLM_CODE = """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention # remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch") num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model( outputs = self.model(
@@ -111,12 +110,17 @@ def patch_training_step_for_ga():
monkeypatch for fixing the training loop for gradient accumulation monkeypatch for fixing the training loop for gradient accumulation
""" """
training_step = get_training_step_code() try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step) training_step, _ = detab_code(training_step)
assert ( if ORIGINAL_CONTEXT_CODE not in training_step:
ORIGINAL_CONTEXT_CODE in training_step return
), "Original training_step code not found" # assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE) training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace( training_step = training_step.replace(
@@ -140,7 +144,7 @@ def patch_training_step_for_ga():
globals(), globals(),
) )
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102 exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step", main_process_only=True) LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821 _fixed_training_step # pylint: disable=undefined-variable # noqa: F821
) )
@@ -164,10 +168,15 @@ def patch_forward_for_ga():
monkeypatch for fixing the training loop for gradient accumulation monkeypatch for fixing the training loop for gradient accumulation
""" """
forward = get_model_forward_code() try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward) forward, _ = detab_code(forward)
assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found" if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE) forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace( forward = forward.replace(
@@ -191,7 +200,7 @@ def patch_forward_for_ga():
globals(), globals(),
) )
exec(forward, globals()) # pylint: disable=exec-used # nosec B102 exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward", main_process_only=True) LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821 _fixed_forward # pylint: disable=undefined-variable # noqa: F821
) )

View File

@@ -9,10 +9,7 @@ import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import LlamaFlashAttention2
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
return attn_output return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def get_self_attn_code() -> str: def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward) forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward return forward
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
def detab_code(code: str) -> Tuple[str, str]: def detab_code(code: str) -> Tuple[str, str]:
spaces = re.match(r"([\s\t]{1,})", code).group(0) try:
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora(): def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code() self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward self_attn_forward
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
globals(), globals(),
) )
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True) LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = ( LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821

View File

@@ -153,7 +153,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
( (
hasattr(model_config, "model_type") hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"] and model_config.model_type in ["llama", "mllama_text_model"]
) )
or cfg.is_llama_derived_model or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower() or "llama" in cfg.base_model.lower()

View File

@@ -1432,20 +1432,6 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def notify_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
LOG.info(
"Unsloth may not be well supported with the latest version of Transformers, "
"resulting in loss that is incorrect."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_torch_compile_deepspeed(cls, data): def check_torch_compile_deepspeed(cls, data):

View File

@@ -380,6 +380,13 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
if self.cfg.gradient_checkpointing == "unsloth": if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
@@ -406,10 +413,14 @@ class ModelLoader:
and self.cfg.flash_attention and self.cfg.flash_attention
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
has_remote_code = ( if "auto_map" in self.model_config:
"auto_map" in self.model_config try:
and "AutoModelForCausalLM" in self.model_config["auto_map"] auto_map_config = self.model_config["auto_map"]
) except TypeError:
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False: if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code

View File

@@ -2,7 +2,9 @@
shared pytest fixtures shared pytest fixtures
""" """
import functools import functools
import importlib
import shutil import shutil
import sys
import tempfile import tempfile
import time import time
@@ -113,3 +115,40 @@ def temp_dir():
yield _temp_dir yield _temp_dir
# Clean up the directory after the test # Clean up the directory after the test
shutil.rmtree(_temp_dir) shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)

View File

@@ -20,7 +20,6 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@pytest.mark.skip(reason="latest unsloth doesn't work with latest transformers")
class TestUnslothQLoRA: class TestUnslothQLoRA:
""" """
Test class for Unsloth QLoRA Llama models Test class for Unsloth QLoRA Llama models
@@ -37,6 +36,9 @@ class TestUnslothQLoRA:
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": sample_packing, "sample_packing": sample_packing,
"flash_attention": True, "flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
@@ -83,6 +85,9 @@ class TestUnslothQLoRA:
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False, "sample_packing": False,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
@@ -134,6 +139,9 @@ class TestUnslothQLoRA:
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False, "sample_packing": False,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",