add support for CP + torch SDPA
This commit is contained in:
74
tests/loaders/test_patch_manager_cp.py
Normal file
74
tests/loaders/test_patch_manager_cp.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Tests for PatchManager context parallel patch selection."""
|
||||
|
||||
import addict
|
||||
|
||||
from axolotl.loaders.patch_manager import PatchManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def _stub_transformers_patches(monkeypatch):
|
||||
"""Replace trainer loss patchers with no-ops for isolation."""
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop",
|
||||
lambda: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
|
||||
def test_patch_manager_applies_flash_cp_patch(monkeypatch):
|
||||
"""When flash attention is enabled, we patch Trainer for CP."""
|
||||
_stub_transformers_patches(monkeypatch)
|
||||
|
||||
patch_calls = {"count": 0}
|
||||
|
||||
def stub_patch():
|
||||
patch_calls["count"] += 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
|
||||
stub_patch,
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": True,
|
||||
"sdp_attention": False,
|
||||
}
|
||||
)
|
||||
|
||||
manager = PatchManager(cfg, addict.Dict())
|
||||
manager._apply_transformers_patches()
|
||||
|
||||
assert patch_calls["count"] == 1
|
||||
|
||||
|
||||
def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch):
|
||||
"""When only SDPA is requested, we should not patch Trainer."""
|
||||
_stub_transformers_patches(monkeypatch)
|
||||
|
||||
patch_calls = {"count": 0}
|
||||
|
||||
def stub_patch():
|
||||
patch_calls["count"] += 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
|
||||
stub_patch,
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": False,
|
||||
"sdp_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
manager = PatchManager(cfg, addict.Dict())
|
||||
manager._apply_transformers_patches()
|
||||
|
||||
assert patch_calls["count"] == 0
|
||||
Reference in New Issue
Block a user