From 08124a7c926ee13db55d810763155db5f20bec5c Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 24 Sep 2025 13:25:46 -0400
Subject: [PATCH] nits

---
 .../monkeypatch/transformers/trainer_context_parallel.py | 4 +++-
 tests/monkeypatch/test_trainer_context_parallel_patch.py | 5 ++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
index 9b0c241d5..74a35e83f 100644
--- a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
+++ b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
@@ -63,4 +63,6 @@ def patch_prepare_context_parallel_inputs() -> None:
     Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
     Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
     Trainer._axolotl_prepare_context_parallel_inputs_patched = True
-    LOG.info("Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP")
+    LOG.debug(
+        "Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
+    )
diff --git a/tests/monkeypatch/test_trainer_context_parallel_patch.py b/tests/monkeypatch/test_trainer_context_parallel_patch.py
index a188f376c..84c883e91 100644
--- a/tests/monkeypatch/test_trainer_context_parallel_patch.py
+++ b/tests/monkeypatch/test_trainer_context_parallel_patch.py
@@ -1,4 +1,4 @@
-"""Tests for the Trainer context parallel patch."""
+"""Tests for the HF Trainer context parallel patch."""
 
 import pytest
 from transformers import Trainer
@@ -33,7 +33,7 @@ def restore_trainer_prepare_method():
         delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
 
 
-def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method):
+def test_patch_attention_guard(restore_trainer_prepare_method):
     """Patch should swap the guard to allow sdpa or flash attention."""
     # Ensure we start from the unpatched method
     if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
@@ -51,7 +51,6 @@ def test_patch_replaces_guard_for_flash_attention(restore_trainer_prepare_method
     assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
     source = Trainer._axolotl_prepare_context_parallel_inputs_source
 
-    # Original guard should be gone, patched guard should be present
     assert GUARD_PATTERN not in source
     assert PATCHED_GUARD in source