From 23389b38b708ac06e35a86bff20000607e90b7dd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 13 Jan 2025 10:34:44 -0500 Subject: [PATCH] bump to latest transformers release --- requirements.txt | 3 ++- src/axolotl/utils/models.py | 8 +++----- tests/patched/test_llama_trainer_ga.py | 3 +++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1f7ac7bba..c9778bebf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,8 @@ liger-kernel==0.5.2 packaging==23.2 peft==0.14.0 -transformers==4.47.1 +# transformers==4.48.1 +transformers @ git+https://github.com/huggingface/transformers.git@v4.48-release tokenizers>=0.21.0 accelerate==1.2.1 datasets==3.2.0 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4a665c111..bff72a1d7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -387,15 +387,13 @@ class ModelLoader: self.patch_attention() if self.cfg.model_config_type == "llama": - from axolotl.monkeypatch.trainer_grad_accum import ( + from axolotl.monkeypatch.trainer_grad_accum import ( # patch_forward_for_ga,; patch_training_step_for_ga, patch_flash_attention_forward, - patch_forward_for_ga, - patch_training_step_for_ga, ) patch_flash_attention_forward() - patch_forward_for_ga() - patch_training_step_for_ga() + # patch_forward_for_ga() + # patch_training_step_for_ga() if self.cfg.sample_packing and self.cfg.s2_attention: raise ValueError( diff --git a/tests/patched/test_llama_trainer_ga.py b/tests/patched/test_llama_trainer_ga.py index 58c229cf3..8a0b338ae 100644 --- a/tests/patched/test_llama_trainer_ga.py +++ b/tests/patched/test_llama_trainer_ga.py @@ -1,12 +1,15 @@ """"Test module for checking whether the Hugging Face Transformers is working as expected.""" import unittest +import pytest + from axolotl.monkeypatch.trainer_grad_accum import ( check_forward_is_patchable, check_training_step_is_patchable, ) +@pytest.mark.skip("may not be needed for latest transformers version") class TestTrainerGAIntegration(unittest.TestCase): """llama monkeypatch integration tests."""