nits
This commit is contained in:
@@ -63,4 +63,6 @@ def patch_prepare_context_parallel_inputs() -> None:
|
|||||||
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
||||||
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
||||||
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
||||||
LOG.info("Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP")
|
LOG.debug(
|
||||||
|
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Tests for the Trainer context parallel patch."""
|
"""Tests for the HF Trainer context parallel patch."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
@@ -33,7 +33,7 @@ def restore_trainer_prepare_method():
|
|||||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
|
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
|
||||||
|
|
||||||
|
|
||||||
def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method):
|
def test_patch_attention_guard(restore_trainer_prepare_method):
|
||||||
"""Patch should swap the guard to allow sdpa or flash attention."""
|
"""Patch should swap the guard to allow sdpa or flash attention."""
|
||||||
# Ensure we start from the unpatched method
|
# Ensure we start from the unpatched method
|
||||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
||||||
@@ -51,7 +51,6 @@ def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method
|
|||||||
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
|
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
|
||||||
|
|
||||||
source = Trainer._axolotl_prepare_context_parallel_inputs_source
|
source = Trainer._axolotl_prepare_context_parallel_inputs_source
|
||||||
# Original guard should be gone, patched guard should be present
|
|
||||||
assert GUARD_PATTERN not in source
|
assert GUARD_PATTERN not in source
|
||||||
assert PATCHED_GUARD in source
|
assert PATCHED_GUARD in source
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user