# 112 lines, 3.1 KiB, Python
"""Unit tests for choosing the correct context parallel implementation."""
|
|
|
|
from types import SimpleNamespace
|
|
|
|
from axolotl.train import execute_training
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
class DummyTrainer:
    """Minimal trainer stub to exercise execute_training.

    Exposes the attributes set below (``model``, ``ref_model``,
    ``accelerator``) and a ``train`` method that only records that it was
    called, so tests can assert on ``train_called`` afterwards.
    """

    def __init__(self):
        # Opaque placeholder; the stub never inspects the model itself.
        self.model = object()
        self.ref_model = None
        # Stand-in accelerator exposing only a ``torch_device_mesh`` attribute.
        self.accelerator = SimpleNamespace(torch_device_mesh=None)
        # Flipped to True by ``train`` so tests can verify the call happened.
        self.train_called = False

    def train(self, resume_from_checkpoint=None):  # pylint: disable=unused-argument
        """Record that training was requested; the argument is ignored."""
        self.train_called = True
|
|
|
|
|
|
class DummyPluginManager:
    """Minimal plugin manager stub."""

    @staticmethod
    def post_train(cfg, model):  # pylint: disable=unused-argument
        """Accept the post-training hook call and do nothing.

        Both arguments are ignored; the method always returns ``None``.
        """
        return None
|
|
|
|
|
|
class DummyContext:
    """Test context manager that records entries/exits.

    Each instance appends one record ``{"kwargs": kwargs}`` to ``recorder`` on
    construction and later flags that same record with ``"entered"`` and
    ``"exited"`` as the ``with`` block is entered and left.
    """

    def __init__(self, recorder, **kwargs):
        """Register a new record for this context in ``recorder``.

        Args:
            recorder: Mutable list that collects one dict per constructed
                context.
            **kwargs: Keyword arguments the context was built with; stored
                under the record's ``"kwargs"`` key.
        """
        # Keep a handle on this instance's own record rather than looking up
        # ``recorder[-1]`` later: another context may be constructed on the
        # same recorder before this one is entered or exited.
        self._record = {"kwargs": kwargs}
        recorder.append(self._record)
        self.recorder = recorder

    def __enter__(self):
        """Mark this context's record as entered and return the context."""
        self._record["entered"] = True
        return self

    def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
        """Mark this context's record as exited; never suppress exceptions."""
        self._record["exited"] = True
        return False
|
|
|
|
|
|
def _base_cfg(**overrides):
    """Build the baseline config shared by the tests in this module.

    Args:
        **overrides: Keys to add to, or replace in, the defaults below.

    Returns:
        A ``DictDefault`` wrapping the defaults merged with ``overrides``.
    """
    defaults = {
        "context_parallel_size": 2,
        "gradient_accumulation_steps": 1,
        "ring_attn_func": None,
        "heads_k_stride": None,
        "rl": None,
        "flash_optimum": False,
    }
    # Later entries win, so any override replaces the matching default.
    return DictDefault({**defaults, **overrides})
|
|
|
|
|
|
def test_execute_training_uses_ring_when_flash(monkeypatch):
    """FlashAttention CP should engage the custom ring context manager."""
    recorder: list[dict] = []

    # Swap the real context manager for a recording stub so the test can see
    # whether it was constructed, with which kwargs, and entered/exited.
    monkeypatch.setattr(
        "axolotl.train.SequenceParallelContextManager",
        lambda **kwargs: DummyContext(recorder, **kwargs),
    )
    # Replace the plugin manager lookup with a stub whose hook does nothing.
    monkeypatch.setattr(
        "axolotl.train.PluginManager.get_instance",
        lambda: DummyPluginManager(),
    )

    cfg = _base_cfg(flash_attention=True, sdp_attention=False)
    trainer = DummyTrainer()

    execute_training(cfg, trainer, resume_from_checkpoint=None)

    assert trainer.train_called
    # Exactly one ring context is expected, built with the configured size
    # and used as a full ``with`` block.
    assert len(recorder) == 1
    assert recorder[0]["kwargs"]["context_parallel_size"] == 2
    assert recorder[0].get("entered") is True
    assert recorder[0].get("exited") is True
|
|
|
|
|
|
def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch):
    """SDPA CP should bypass the ring context manager."""
    invoked = {"count": 0}

    class NoOpContext:
        """Do-nothing context manager returned if the ring manager is built."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
            return False

    def counting_factory(**kwargs):  # pylint: disable=unused-argument
        """Count each construction attempt, then hand back a harmless context."""
        invoked["count"] += 1
        return NoOpContext()

    # Any construction of the ring context manager bumps the counter, so the
    # final assertion can prove it was never built on the SDPA path.
    monkeypatch.setattr(
        "axolotl.train.SequenceParallelContextManager",
        counting_factory,
    )
    # Replace the plugin manager lookup with a stub whose hook does nothing.
    monkeypatch.setattr(
        "axolotl.train.PluginManager.get_instance",
        lambda: DummyPluginManager(),
    )

    cfg = _base_cfg(flash_attention=False, sdp_attention=True)
    trainer = DummyTrainer()

    execute_training(cfg, trainer, resume_from_checkpoint=None)

    assert trainer.train_called
    assert invoked["count"] == 0
|