Compare commits
9 Commits
transforme
...
e2e-fsdp-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39ab9626f1 | ||
|
|
26bd81cec0 | ||
|
|
1302e31049 | ||
|
|
be5f554a62 | ||
|
|
22319182ab | ||
|
|
440aab8a6f | ||
|
|
5bef19064b | ||
|
|
743ba62bd5 | ||
|
|
f9a7748bd8 |
@@ -2,6 +2,6 @@
|
||||
set -e
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
|
||||
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.14.0
|
||||
transformers==4.46.3
|
||||
transformers>=4.46.3
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.45.0
|
||||
accelerate==1.1.0
|
||||
accelerate==1.2.0
|
||||
datasets==3.1.0
|
||||
deepspeed==0.15.4
|
||||
pydantic==2.6.3
|
||||
@@ -31,7 +31,7 @@ art
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq==0.2.7.post2
|
||||
autoawq==0.2.7.post3
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.4.2
|
||||
|
||||
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
|
||||
gcsfs>=2024.5.0
|
||||
# adlfs
|
||||
|
||||
trl==0.12.0
|
||||
trl==0.12.1
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
|
||||
@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from packaging import version
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
@@ -957,13 +958,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
|
||||
return res
|
||||
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
start_time (`Optional[float]`):
|
||||
The start of training.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
@@ -971,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs)
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
try:
|
||||
return super().log(logs, start_time)
|
||||
except TypeError:
|
||||
return super().log(logs) # transformers<=4.46
|
||||
return super().log(logs) # transformers<=4.46
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
@@ -1155,6 +1164,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
@@ -1163,6 +1188,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
@@ -1171,6 +1212,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# train metrics should have no prefix, eval should have 'eval_'
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
# accumulate average metrics from sums and lengths
|
||||
for split in ["chosen", "rejected"]:
|
||||
if f"count/{split}" in self._stored_metrics[train_eval]:
|
||||
count_sum = (
|
||||
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
for metric in ["rewards", "logps", "logits"]:
|
||||
logs[f"{prefix}{metric}/{split}"] = (
|
||||
torch.Tensor(
|
||||
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
/ count_sum
|
||||
)
|
||||
# delete obsolete metric
|
||||
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||
del self._stored_metrics[train_eval][f"count/{split}"]
|
||||
# calculate reward margin
|
||||
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
||||
logs[f"{prefix}rewards/margins"] = (
|
||||
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
||||
)
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
@@ -1179,6 +1263,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
@@ -1187,6 +1287,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
|
||||
80
src/axolotl/monkeypatch/trainer_fsdp_optim.py
Normal file
80
src/axolotl/monkeypatch/trainer_fsdp_optim.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
fix for FSDP optimizer save in trainer w 4.47.0
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def get_training_loop_code() -> str:
|
||||
training_loop = inspect.getsource(
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_training_loop_is_patchable() -> bool:
|
||||
training_loop = get_training_loop_code()
|
||||
training_loop, _ = detab_code(training_loop)
|
||||
return ORIGINAL_TRAINER_CODE in training_loop
|
||||
|
||||
|
||||
def patch_training_loop_for_fsdp():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for fsdp with optimizer save
|
||||
"""
|
||||
|
||||
try:
|
||||
training_loop = get_training_loop_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||
training_loop
|
||||
)
|
||||
training_loop, _ = detab_code(training_loop)
|
||||
if ORIGINAL_TRAINER_CODE not in training_loop:
|
||||
return
|
||||
|
||||
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
|
||||
training_loop = training_loop.replace(
|
||||
"def _inner_training_loop(",
|
||||
"def _fixed_inner_training_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in training_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
206
src/axolotl/monkeypatch/trainer_grad_accum.py
Normal file
206
src/axolotl/monkeypatch/trainer_grad_accum.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
fix for FSDP gradient accumulation
|
||||
see https://github.com/huggingface/transformers/pull/35128
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from transformers import LlamaForCausalLM, Trainer
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||
|
||||
ORIGINAL_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
"""
|
||||
|
||||
PATCHED_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
"""
|
||||
|
||||
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
"""
|
||||
|
||||
PATCHED_LLAMA_FCLM_CODE = """
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||
"""
|
||||
|
||||
|
||||
def get_training_step_code() -> str:
|
||||
training_step = inspect.getsource(
|
||||
Trainer.training_step # pylint: disable=protected-access
|
||||
)
|
||||
return training_step
|
||||
|
||||
|
||||
def check_training_step_is_patchable() -> bool:
|
||||
training_step = get_training_step_code()
|
||||
training_step, _ = detab_code(training_step)
|
||||
return ORIGINAL_CONTEXT_CODE in training_step
|
||||
|
||||
|
||||
def patch_training_step_for_ga():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for gradient accumulation
|
||||
"""
|
||||
|
||||
try:
|
||||
training_step = get_training_step_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_training_step = training_step # pylint: disable=protected-access
|
||||
training_step, _ = detab_code(training_step)
|
||||
if ORIGINAL_CONTEXT_CODE not in training_step:
|
||||
return
|
||||
# assert (
|
||||
# ORIGINAL_CONTEXT_CODE in training_step
|
||||
# ), "Original training_step code not found"
|
||||
|
||||
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||
training_step = training_step.replace(
|
||||
"def training_step(",
|
||||
"def _fixed_training_step(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in training_step:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching training_step")
|
||||
Trainer.training_step = ( # pylint: disable=protected-access
|
||||
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def get_model_forward_code() -> str:
|
||||
forward = inspect.getsource(
|
||||
LlamaForCausalLM.forward # pylint: disable=protected-access
|
||||
)
|
||||
return forward
|
||||
|
||||
|
||||
def check_forward_is_patchable() -> bool:
|
||||
forward = get_model_forward_code()
|
||||
forward, _ = detab_code(forward)
|
||||
return ORIGINAL_LLAMA_FCLM_CODE in forward
|
||||
|
||||
|
||||
def patch_forward_for_ga():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for gradient accumulation
|
||||
"""
|
||||
|
||||
try:
|
||||
forward = get_model_forward_code()
|
||||
except OSError:
|
||||
return
|
||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||
forward, _ = detab_code(forward)
|
||||
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
|
||||
return
|
||||
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
|
||||
|
||||
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
|
||||
forward = forward.replace(
|
||||
"def forward(",
|
||||
"def _fixed_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.models.llama.modeling_llama):
|
||||
if item in forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.models.llama.modeling_llama import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching forward")
|
||||
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
|
||||
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
@@ -9,10 +9,7 @@ import torch
|
||||
from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||
|
||||
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
|
||||
return attn_output
|
||||
|
||||
|
||||
def get_forward_code() -> str:
|
||||
forward = inspect.getsource(LlamaForCausalLM.forward)
|
||||
return forward
|
||||
|
||||
|
||||
def get_self_attn_code() -> str:
|
||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||
return forward
|
||||
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||
|
||||
|
||||
def detab_code(code: str) -> Tuple[str, str]:
|
||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||
try:
|
||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||
except AttributeError:
|
||||
return code, ""
|
||||
return code, spaces
|
||||
|
||||
|
||||
self_attn_lora_patched = False # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def patch_self_attn_lora():
|
||||
global self_attn_lora_patched # pylint: disable=global-statement
|
||||
if self_attn_lora_patched:
|
||||
# prevent patching multiple times
|
||||
return
|
||||
self_attn_forward = get_self_attn_code()
|
||||
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
||||
self_attn_forward
|
||||
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
self_attn_lora_patched = True
|
||||
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||
LlamaFlashAttention2.forward = (
|
||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
|
||||
@@ -153,7 +153,7 @@ def normalize_config(cfg):
|
||||
cfg.is_llama_derived_model = (
|
||||
(
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type == ["llama", "mllama_text_model"]
|
||||
and model_config.model_type in ["llama", "mllama_text_model"]
|
||||
)
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model.lower()
|
||||
|
||||
@@ -380,12 +380,28 @@ class ModelLoader:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
|
||||
if self.cfg.fsdp:
|
||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||
patch_training_loop_for_fsdp,
|
||||
)
|
||||
|
||||
patch_training_loop_for_fsdp()
|
||||
|
||||
if self.cfg.gradient_checkpointing == "unsloth":
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
self.patch_attention()
|
||||
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
patch_forward_for_ga,
|
||||
patch_training_step_for_ga,
|
||||
)
|
||||
|
||||
patch_forward_for_ga()
|
||||
patch_training_step_for_ga()
|
||||
|
||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
raise ValueError(
|
||||
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
||||
@@ -397,10 +413,14 @@ class ModelLoader:
|
||||
and self.cfg.flash_attention
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
has_remote_code = (
|
||||
"auto_map" in self.model_config
|
||||
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
||||
)
|
||||
if "auto_map" in self.model_config:
|
||||
try:
|
||||
auto_map_config = self.model_config["auto_map"]
|
||||
except TypeError:
|
||||
auto_map_config = self.model_config.auto_map
|
||||
has_remote_code = "AutoModelForCausalLM" in auto_map_config
|
||||
else:
|
||||
has_remote_code = False
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
shared pytest fixtures
|
||||
"""
|
||||
import functools
|
||||
import importlib
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
@@ -113,3 +115,40 @@ def temp_dir():
|
||||
yield _temp_dir
|
||||
# Clean up the directory after the test
|
||||
shutil.rmtree(_temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
LlamaFlashAttention2.forward = original_fa2_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
|
||||
# Reset other known monkeypatches
|
||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||
("transformers",),
|
||||
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
||||
("transformers.trainer", ["Trainer"]),
|
||||
("transformers.loss.loss_utils",),
|
||||
]
|
||||
for module_name_tuple in modules_to_reset:
|
||||
module_name = module_name_tuple[0]
|
||||
module = importlib.import_module(module_name)
|
||||
sys.modules[module_name] = module
|
||||
importlib.reload(sys.modules[module_name])
|
||||
if len(module_name_tuple) > 1:
|
||||
module_globals = module_name_tuple[1]
|
||||
for module_global in module_globals:
|
||||
globals().pop(module_global, None)
|
||||
|
||||
@@ -36,6 +36,9 @@ class TestUnslothQLoRA:
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"unsloth_lora_mlp": True,
|
||||
"unsloth_lora_qkv": True,
|
||||
"unsloth_lora_o": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
@@ -82,6 +85,9 @@ class TestUnslothQLoRA:
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"unsloth_lora_mlp": True,
|
||||
"unsloth_lora_qkv": True,
|
||||
"unsloth_lora_o": True,
|
||||
"sample_packing": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
@@ -133,6 +139,9 @@ class TestUnslothQLoRA:
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"unsloth_lora_mlp": True,
|
||||
"unsloth_lora_qkv": True,
|
||||
"unsloth_lora_o": True,
|
||||
"sample_packing": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
|
||||
25
tests/patched/test_llama_trainer_ga.py
Normal file
25
tests/patched/test_llama_trainer_ga.py
Normal file
@@ -0,0 +1,25 @@
|
||||
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
check_forward_is_patchable,
|
||||
check_training_step_is_patchable,
|
||||
)
|
||||
|
||||
|
||||
class TestTrainerGAIntegration(unittest.TestCase):
|
||||
"""llama monkeypatch integration tests."""
|
||||
|
||||
def test_train_step_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
check_training_step_is_patchable(),
|
||||
"HF transformers Trainer.training_step has changed and isn't patchable",
|
||||
)
|
||||
|
||||
def test_model_forward_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
check_forward_is_patchable(),
|
||||
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
|
||||
)
|
||||
Reference in New Issue
Block a user