Compare commits

...

3 Commits

Author SHA1 Message Date
Wing Lian
9c0fa60220 fsdp2 w evals fixed upstream 2025-08-11 16:26:42 -04:00
Wing Lian
8efdc59796 just assume that fa supports window 2025-08-11 16:09:11 -04:00
Wing Lian
172b08b209 integration check for transformers#40002 2025-08-11 10:06:11 -04:00
3 changed files with 2 additions and 6 deletions

View File

@@ -14,7 +14,7 @@ packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft==0.17.0 peft==0.17.0
transformers==4.55.0 transformers @ git+https://github.com/vasqu/transformers@fix-fa-integration
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.10.0 accelerate==1.10.0
datasets==4.0.0 datasets==4.0.0

View File

@@ -18,9 +18,7 @@ from torch.distributed import DeviceMesh
try: try:
from transformers.modeling_flash_attention_utils import _flash_supports_window from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError: except ImportError:
from transformers.modeling_flash_attention_utils import ( _flash_supports_window = True
_flash_supports_window_size as _flash_supports_window,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger

View File

@@ -3,7 +3,6 @@
import unittest import unittest
from axolotl.monkeypatch.transformers.trainer_loss_calc import ( from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable, check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable, check_maybe_log_save_evaluate_is_patchable,
) )
@@ -20,7 +19,6 @@ class TestTrainerLossCalc(unittest.TestCase):
the patched code changes upstream. the patched code changes upstream.
""" """
assert check_evaluation_loop_is_patchable() assert check_evaluation_loop_is_patchable()
assert check_evaluation_loop_is_fsdp2_patchable()
assert check_maybe_log_save_evaluate_is_patchable() assert check_maybe_log_save_evaluate_is_patchable()