From 7669a03fb4cebd02bedcb8a12d10c3ac66ec2fc5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 9 Jan 2025 16:01:59 -0500 Subject: [PATCH] update upstream HF deps (#2239) * bump axolotl contribs for upstream main conflicts: * bump datasets, tokenizer, trl * remove log workarounds in trl * bump lm-eval * remove unsloth_ import from critical path * remove llama fa2 from conftest * unsloth breaks with latest upstream --- requirements.txt | 10 +- src/axolotl/core/trainer_builder.py | 108 +----------------- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 2 +- src/axolotl/monkeypatch/trainer_grad_accum.py | 2 +- src/axolotl/monkeypatch/unsloth_.py | 13 +-- src/axolotl/monkeypatch/utils.py | 12 +- tests/conftest.py | 12 +- tests/e2e/patched/test_unsloth_integration.py | 5 + tests/e2e/patched/test_unsloth_qlora.py | 3 + 9 files changed, 36 insertions(+), 131 deletions(-) diff --git a/requirements.txt b/requirements.txt index 283b5cc2d..550fe6eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,11 +14,11 @@ packaging==23.2 peft==0.14.0 transformers==4.47.1 -tokenizers>=0.20.1 +tokenizers>=0.21.0 accelerate==1.2.1 -datasets==3.1.0 +datasets==3.2.0 deepspeed==0.16.1 -trl==0.12.1 +trl==0.13.0 optimum==1.16.2 hf_transfer @@ -53,7 +53,7 @@ zstandard==0.22.0 fastcore # lm eval harness -lm_eval==0.4.4 +lm_eval==0.4.7 langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 @@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2 torchao==0.7.0 schedulefree==1.3.0 -axolotl-contribs-lgpl==0.0.2 +axolotl-contribs-lgpl==0.0.3 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 5cc2b2ea9..176ce4174 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union import torch import transformers from datasets import Dataset -from packaging import version from peft.optimizers import create_loraplus_optimizer from torch import nn from torch.optim.lr_scheduler import OneCycleLR @@ -984,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - try: - return super().log(logs, start_time) - except TypeError: - return super().log(logs) # transformers<=4.46 - return super().log(logs) # transformers<=4.46 + return super().log(logs, start_time) def store_metrics( self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" @@ -1173,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): torch.cuda.empty_cache() return loss - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - # TODO remove once trl supports the updated to the Trainer.log method - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super(DPOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) - # transformers<=4.46 - return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call - class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): """ @@ -1197,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): tag_names = ["axolotl", "orpo"] - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - # TODO remove once trl supports the updated to the Trainer.log method - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) - # transformers<=4.46 - return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call - class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): """ @@ -1221,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): tag_names = ["axolotl", "kto"] - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - # TODO remove once trl supports the updated to the Trainer.log method - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # train metrics should have no prefix, eval should have 'eval_' - prefix = "eval_" if train_eval == "eval" else "" - # accumulate average metrics from sums and lengths - for split in ["chosen", "rejected"]: - if f"count/{split}" in self._stored_metrics[train_eval]: - count_sum = ( - torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]) - .sum() - .item() - ) - for metric in ["rewards", "logps", "logits"]: - logs[f"{prefix}{metric}/{split}"] = ( - torch.Tensor( - self._stored_metrics[train_eval][f"{metric}/{split}_sum"] - ) - .sum() - .item() - / count_sum - ) - # delete obsolete metric - del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] - del self._stored_metrics[train_eval][f"count/{split}"] - # calculate reward margin - if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: - logs[f"{prefix}rewards/margins"] = ( - logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] - ) - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super(KTOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) - # transformers<=4.46 - return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call - class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): """ @@ -1272,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): tag_names = ["axolotl", "cpo"] - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - # TODO remove once trl supports the updated to the Trainer.log method - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super(CPOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) - # transformers<=4.46 - return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call - class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): """ @@ -1296,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): tag_names = ["axolotl", "reward"] - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - # TODO remove once trl supports the updated to the Trainer.log method - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super(RewardTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) - # transformers<=4.46 - return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call - class TrainerBuilderBase(abc.ABC): """ diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 185f742d7..00c2dfebc 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -6,7 +6,7 @@ import logging from transformers import Trainer -from axolotl.monkeypatch.unsloth_ import detab_code +from axolotl.monkeypatch.utils import detab_code LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py index 550f00e30..05d706704 100644 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ b/src/axolotl/monkeypatch/trainer_grad_accum.py @@ -8,7 +8,7 @@ import logging from transformers import LlamaForCausalLM, Trainer from transformers.modeling_flash_attention_utils import _flash_attention_forward -from axolotl.monkeypatch.unsloth_ import detab_code +from axolotl.monkeypatch.utils import detab_code LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 21fdb7edf..c81bacbfc 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -1,9 +1,7 @@ """module for patching with unsloth optimizations""" import inspect -import re import types -from typing import Tuple import torch from accelerate.logging import get_logger @@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM from torch import nn from transformers.models.llama.modeling_llama import LlamaFlashAttention2 +from axolotl.monkeypatch.utils import detab_code + LOG = get_logger("axolotl.monkeypatch.unsloth") ORIGINAL_QKV_CODE = """ @@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: raise ValueError("Unsupported model type") -def detab_code(code: str) -> Tuple[str, str]: - try: - spaces = re.match(r"([\s\t]{1,})", code).group(0) - code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) - except AttributeError: - return code, "" - return code, spaces - - self_attn_lora_patched = False # pylint: disable=invalid-name diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index f29f21be7..c2772b471 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -1,7 +1,8 @@ """ Shared utils for the monkeypatches """ -from typing import Optional +import re +from typing import Optional, Tuple import torch import torch.nn.functional as F @@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa( mask_2d_to_4d(attention_mask, dtype=dtype), *args, ) + + +def detab_code(code: str) -> Tuple[str, str]: + try: + spaces = re.match(r"([\s\t]{1,})", code).group(0) + code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) + except AttributeError: + return code, "" + return code, spaces diff --git a/tests/conftest.py b/tests/conftest.py index f2519cdcf..85e276722 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -120,13 +120,12 @@ def temp_dir(): @pytest.fixture(scope="function", autouse=True) def cleanup_monkeypatches(): from transformers import Trainer - from transformers.models.llama.modeling_llama import ( + from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2, LlamaAttention, - LlamaFlashAttention2, LlamaForCausalLM, ) - original_fa2_forward = LlamaFlashAttention2.forward + # original_fa2_forward = LlamaFlashAttention2.forward original_llama_attn_forward = LlamaAttention.forward original_llama_forward = LlamaForCausalLM.forward original_trainer_inner_training_loop = ( @@ -136,7 +135,7 @@ def cleanup_monkeypatches(): # monkey patches can happen inside the tests yield # Reset LlamaFlashAttention2 forward - LlamaFlashAttention2.forward = original_fa2_forward + # LlamaFlashAttention2.forward = original_fa2_forward LlamaAttention.forward = original_llama_attn_forward LlamaForCausalLM.forward = original_llama_forward Trainer._inner_training_loop = ( # pylint: disable=protected-access @@ -149,7 +148,10 @@ def cleanup_monkeypatches(): ("transformers.models.llama",), ( "transformers.models.llama.modeling_llama", - ["LlamaFlashAttention2", "LlamaAttention"], + [ + # "LlamaFlashAttention2", + "LlamaAttention", + ], ), ("transformers.trainer",), ("transformers", ["Trainer"]), diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py index 888274286..bc6476dab 100644 --- a/tests/e2e/patched/test_unsloth_integration.py +++ b/tests/e2e/patched/test_unsloth_integration.py @@ -1,9 +1,14 @@ """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" import unittest +import pytest + from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable +@pytest.mark.skip( + reason="Unsloth integration will be broken going into latest transformers" +) class TestUnslothIntegration(unittest.TestCase): """Unsloth monkeypatch integration tests.""" diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index b58406185..0c0ee8610 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true" # pylint: disable=duplicate-code +@pytest.mark.skip( + reason="Unsloth integration will be broken going into latest transformers" +) class TestUnslothQLoRA: """ Test class for Unsloth QLoRA Llama models