This commit is contained in:
Dan Saunders
2025-09-24 13:25:46 -04:00
parent 56e0a77e0d
commit 08124a7c92
2 changed files with 5 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
"""Tests for the Trainer context parallel patch."""
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer
@@ -33,7 +33,7 @@ def restore_trainer_prepare_method():
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method):
def test_patch_attention_guard(restore_trainer_prepare_method):
"""Patch should swap the guard to allow sdpa or flash attention."""
# Ensure we start from the unpatched method
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
@@ -51,7 +51,6 @@ def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
source = Trainer._axolotl_prepare_context_parallel_inputs_source
# Original guard should be gone, patched guard should be present
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source