nits
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Tests for the Trainer context parallel patch."""
|
||||
"""Tests for the HF Trainer context parallel patch."""
|
||||
|
||||
import pytest
|
||||
from transformers import Trainer
|
||||
@@ -33,7 +33,7 @@ def restore_trainer_prepare_method():
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
|
||||
|
||||
|
||||
def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method):
|
||||
def test_patch_attention_guard(restore_trainer_prepare_method):
|
||||
"""Patch should swap the guard to allow sdpa or flash attention."""
|
||||
# Ensure we start from the unpatched method
|
||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
||||
@@ -51,7 +51,6 @@ def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method
|
||||
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
|
||||
|
||||
source = Trainer._axolotl_prepare_context_parallel_inputs_source
|
||||
# Original guard should be gone, patched guard should be present
|
||||
assert GUARD_PATTERN not in source
|
||||
assert PATCHED_GUARD in source
|
||||
|
||||
|
||||
Reference in New Issue
Block a user