Compare commits

...

9 Commits

Author SHA1 Message Date
Wing Lian
39ab9626f1 add transformers module to cleanup 2024-12-08 14:52:54 -05:00
Wing Lian
26bd81cec0 re-enable tests w change in patching 2024-12-08 14:52:09 -05:00
Wing Lian
1302e31049 Transformers version flexibility and FSDP optimizer patch (#2155)
* allow flexibility in transformers version for FSDP

* more flexibility with dev versions of 4.47.0.dev0

* add patch for fsdp

* fix typo

* correct fn name

* stray character

* fix patch

* reset Trainer too

* also reset Trainer.training_step

* allow tests/patched to run more than one process on e2e runner

* skip tests/patched in e2e for now since it's run in regular pytest
2024-12-08 14:50:40 -05:00
Wing Lian
be5f554a62 bump autoawq to 0.2.7.post3 (#2150) 2024-12-07 22:24:09 -05:00
Wing Lian
22319182ab fix for auto_map check when using remote code and multipack for models like deepseek (#2151) [skip ci] 2024-12-07 22:23:52 -05:00
Wing Lian
440aab8a6f add --version support to axolotl cli (#2152) [skip ci] 2024-12-07 22:23:33 -05:00
Wing Lian
5bef19064b [tests] reset known modules that are patched on each test function end (#2147)
* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manually workaround for llama fa2
2024-12-07 17:24:46 -05:00
Wing Lian
743ba62bd5 Transformers 4.47.0 (#2138)
* bump transformers and trl

* fix: update trainer.log signature

* fix trl trainer.log interfaces

* broken 🦥 with latest transformers

* skip parent, call grandparent - yeah, super janky

* update HF HUB env var and fix reward trainer log since it doesn't directly override log

* also bump accelerate

* patches for llama ga

* detab the code to check

* fix whitespace for patch check

* play nicely with CI tests since we patch everytime

* fix pop default in case it doesn't exist

* more tweaks to make patches nicer in CI

* fix detab for when there are possibly multiple patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-07 05:03:01 -05:00
Chirag Jain
f9a7748bd8 Fix llama type model check (#2142) [skip ci] 2024-12-07 05:02:32 -05:00
16 changed files with 520 additions and 27 deletions

View File

@@ -2,6 +2,6 @@
set -e set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1" ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub" ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1" ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.46.3 transformers>=4.46.3
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.45.0 bitsandbytes==0.45.0
accelerate==1.1.0 accelerate==1.2.0
datasets==3.1.0 datasets==3.1.0
deepspeed==0.15.4 deepspeed==0.15.4
pydantic==2.6.3 pydantic==2.6.3
@@ -31,7 +31,7 @@ art
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq==0.2.7.post2 autoawq==0.2.7.post3
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.4.2 liger-kernel==0.4.2
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0 gcsfs>=2024.5.0
# adlfs # adlfs
trl==0.12.0 trl==0.12.1
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
env_capabilities={ env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0] "torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
}, },
) )

View File

@@ -5,6 +5,7 @@ from typing import Optional
import click import click
import axolotl
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group() @click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli(): def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -957,13 +958,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return res return res
def log(self, logs: Dict[str, float]) -> None: def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
""" """
Log `logs` on the various objects watching training, including stored metrics. Log `logs` on the various objects watching training, including stored metrics.
Args: Args:
logs (`Dict[str, float]`): logs (`Dict[str, float]`):
The values to log. The values to log.
start_time (`Optional[float]`):
The start of training.
""" """
# logs either has 'loss' or 'eval_loss' # logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
@@ -971,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1155,6 +1164,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """
@@ -1163,6 +1188,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
""" """
@@ -1171,6 +1212,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
""" """
@@ -1179,6 +1263,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
""" """
@@ -1187,6 +1287,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """

View File

@@ -0,0 +1,80 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -0,0 +1,206 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -9,10 +9,7 @@ import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import LlamaFlashAttention2
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
return attn_output return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def get_self_attn_code() -> str: def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward) forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward return forward
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
def detab_code(code: str) -> Tuple[str, str]: def detab_code(code: str) -> Tuple[str, str]:
spaces = re.match(r"([\s\t]{1,})", code).group(0) try:
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora(): def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code() self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward self_attn_forward
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
globals(), globals(),
) )
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True) LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = ( LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821

View File

@@ -153,7 +153,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
( (
hasattr(model_config, "model_type") hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"] and model_config.model_type in ["llama", "mllama_text_model"]
) )
or cfg.is_llama_derived_model or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower() or "llama" in cfg.base_model.lower()

View File

@@ -380,12 +380,28 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
if self.cfg.gradient_checkpointing == "unsloth": if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention: if self.cfg.flash_attention:
self.patch_attention() self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention: if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError( raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \ "Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -397,10 +413,14 @@ class ModelLoader:
and self.cfg.flash_attention and self.cfg.flash_attention
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
has_remote_code = ( if "auto_map" in self.model_config:
"auto_map" in self.model_config try:
and "AutoModelForCausalLM" in self.model_config["auto_map"] auto_map_config = self.model_config["auto_map"]
) except TypeError:
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False: if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code

View File

@@ -2,7 +2,9 @@
shared pytest fixtures shared pytest fixtures
""" """
import functools import functools
import importlib
import shutil import shutil
import sys
import tempfile import tempfile
import time import time
@@ -113,3 +115,40 @@ def temp_dir():
yield _temp_dir yield _temp_dir
# Clean up the directory after the test # Clean up the directory after the test
shutil.rmtree(_temp_dir) shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)

View File

@@ -36,6 +36,9 @@ class TestUnslothQLoRA:
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": sample_packing, "sample_packing": sample_packing,
"flash_attention": True, "flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
@@ -82,6 +85,9 @@ class TestUnslothQLoRA:
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False, "sample_packing": False,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
@@ -133,6 +139,9 @@ class TestUnslothQLoRA:
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False, "sample_packing": False,
"load_in_4bit": True, "load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",

View File

@@ -0,0 +1,25 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable,
check_training_step_is_patchable,
)
class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests."""
def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_training_step_is_patchable(),
"HF transformers Trainer.training_step has changed and isn't patchable",
)
def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_forward_is_patchable(),
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
)