Compare commits

...

9 Commits

Author SHA1 Message Date
Wing Lian
39ab9626f1 add transformers module to cleanup 2024-12-08 14:52:54 -05:00
Wing Lian
26bd81cec0 re-enable tests w change in patching 2024-12-08 14:52:09 -05:00
Wing Lian
1302e31049 Transformers version flexibility and FSDP optimizer patch (#2155)
* allow flexibility in transformers version for FSDP

* more flexibility with dev versions of 4.47.0.dev0

* add patch for fsdp

* fix typo

* correct fn name

* stray character

* fix patch

* reset Trainer too

* also reset Trainer.training_step

* allow tests/patched to run more than one process on e2e runner

* skip tests/patched in e2e for now since it's run in regular pytest
2024-12-08 14:50:40 -05:00
Wing Lian
be5f554a62 bump autoawq to 0.2.7.post3 (#2150) 2024-12-07 22:24:09 -05:00
Wing Lian
22319182ab fix for auto_map check when using remote code and multipack for models like deepseek (#2151) [skip ci] 2024-12-07 22:23:52 -05:00
Wing Lian
440aab8a6f add --version support to axolotl cli (#2152) [skip ci] 2024-12-07 22:23:33 -05:00
Wing Lian
5bef19064b [tests] reset known modules that are patched on each test function end (#2147)
* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manually workaround for llama fa2
2024-12-07 17:24:46 -05:00
Wing Lian
743ba62bd5 Transformers 4.47.0 (#2138)
* bump transformers and trl

* fix: update trainer.log signature

* fix trl trainer.log interfaces

* broken 🦥 with latest transformers

* skip parent, call grandparent - yeah, super janky

* update HF HUB env var and fix reward trainer log since it doesn't directly override log

* also bump accelerate

* patches for llama ga

* detab the code to check

* fix whitespace for patch check

* play nicely with CI tests since we patch everytime

* fix pop default in case it doesn't exist

* more tweaks to make patches nicer in CI

* fix detab for when there are possibly multiple patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-07 05:03:01 -05:00
Chirag Jain
f9a7748bd8 Fix llama type model check (#2142) [skip ci] 2024-12-07 05:02:32 -05:00
16 changed files with 520 additions and 27 deletions

View File

@@ -2,6 +2,6 @@
set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.14.0
transformers==4.46.3
transformers>=4.46.3
tokenizers>=0.20.1
bitsandbytes==0.45.0
accelerate==1.1.0
accelerate==1.2.0
datasets==3.1.0
deepspeed==0.15.4
pydantic==2.6.3
@@ -31,7 +31,7 @@ art
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post2
autoawq==0.2.7.post3
triton>=2.3.0
liger-kernel==0.4.2
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.12.0
trl==0.12.1
zstandard==0.22.0
fastcore

View File

@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
},
)

View File

@@ -5,6 +5,7 @@ from typing import Optional
import click
import axolotl
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -957,13 +958,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return res
def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -971,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1155,6 +1164,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1163,6 +1188,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1171,6 +1212,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1179,6 +1263,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1187,6 +1287,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
"""

View File

@@ -0,0 +1,80 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -0,0 +1,206 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -9,10 +9,7 @@ import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
LOG = get_logger("axolotl.monkeypatch.unsloth")
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
def detab_code(code: str) -> Tuple[str, str]:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821

View File

@@ -153,7 +153,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
and model_config.model_type in ["llama", "mllama_text_model"]
)
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()

View File

@@ -380,12 +380,28 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -397,10 +413,14 @@ class ModelLoader:
and self.cfg.flash_attention
and self.cfg.sample_packing
):
has_remote_code = (
"auto_map" in self.model_config
and "AutoModelForCausalLM" in self.model_config["auto_map"]
)
if "auto_map" in self.model_config:
try:
auto_map_config = self.model_config["auto_map"]
except TypeError:
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code

View File

@@ -2,7 +2,9 @@
shared pytest fixtures
"""
import functools
import importlib
import shutil
import sys
import tempfile
import time
@@ -113,3 +115,40 @@ def temp_dir():
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)

View File

@@ -36,6 +36,9 @@ class TestUnslothQLoRA:
"sequence_len": 1024,
"sample_packing": sample_packing,
"flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
@@ -82,6 +85,9 @@ class TestUnslothQLoRA:
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
@@ -133,6 +139,9 @@ class TestUnslothQLoRA:
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",

View File

@@ -0,0 +1,25 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable,
check_training_step_is_patchable,
)
class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests."""
def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_training_step_is_patchable(),
"HF transformers Trainer.training_step has changed and isn't patchable",
)
def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_forward_is_patchable(),
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
)