This commit is contained in:
Dan Saunders
2025-09-24 13:25:46 -04:00
parent 56e0a77e0d
commit 08124a7c92
2 changed files with 5 additions and 4 deletions

View File

@@ -63,4 +63,6 @@ def patch_prepare_context_parallel_inputs() -> None:
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.info("Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP")
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)

View File

@@ -1,4 +1,4 @@
"""Tests for the Trainer context parallel patch."""
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer
@@ -33,7 +33,7 @@ def restore_trainer_prepare_method():
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method):
def test_patch_attention_guard(restore_trainer_prepare_method):
"""Patch should swap the guard to allow sdpa or flash attention."""
# Ensure we start from the unpatched method
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
@@ -51,7 +51,6 @@ def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
source = Trainer._axolotl_prepare_context_parallel_inputs_source
# Original guard should be gone, patched guard should be present
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source