update upstream HF deps (#2239)

* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
This commit is contained in:
Wing Lian
2025-01-09 16:01:59 -05:00
committed by GitHub
parent 6553683170
commit 7669a03fb4
9 changed files with 36 additions and 131 deletions

View File

@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.47.1 transformers==4.47.1
tokenizers>=0.20.1 tokenizers>=0.21.0
accelerate==1.2.1 accelerate==1.2.1
datasets==3.1.0 datasets==3.2.0
deepspeed==0.16.1 deepspeed==0.16.1
trl==0.12.1 trl==0.13.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore fastcore
# lm eval harness # lm eval harness
lm_eval==0.4.4 lm_eval==0.4.7
langdetect==1.0.9 langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.2 axolotl-contribs-lgpl==0.0.3

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -984,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): return super().log(logs, start_time)
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1173,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """
@@ -1197,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
""" """
@@ -1221,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
""" """
@@ -1272,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
""" """
@@ -1296,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,9 +1,7 @@
"""module for patching with unsloth optimizations""" """module for patching with unsloth optimizations"""
import inspect import inspect
import re
import types import types
from typing import Tuple
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """ ORIGINAL_QKV_CODE = """
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type") raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,7 +1,8 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
from typing import Optional import re
from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype), mask_2d_to_4d(attention_mask, dtype=dtype),
*args, *args,
) )
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -120,13 +120,12 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
from transformers import Trainer from transformers import Trainer
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention, LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM, LlamaForCausalLM,
) )
original_fa2_forward = LlamaFlashAttention2.forward # original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = ( original_trainer_inner_training_loop = (
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests # monkey patches can happen inside the tests
yield yield
# Reset LlamaFlashAttention2 forward # Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward # LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
("transformers.models.llama",), ("transformers.models.llama",),
( (
"transformers.models.llama.modeling_llama", "transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"], [
# "LlamaFlashAttention2",
"LlamaAttention",
],
), ),
("transformers.trainer",), ("transformers.trainer",),
("transformers", ["Trainer"]), ("transformers", ["Trainer"]),

View File

@@ -1,9 +1,14 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase): class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests.""" """Unsloth monkeypatch integration tests."""

View File

@@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA: class TestUnslothQLoRA:
""" """
Test class for Unsloth QLoRA Llama models Test class for Unsloth QLoRA Llama models