fix async prefetch with nemogym (#3606)
This commit is contained in:
@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
|
||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
|
||||
class TestVllmLoraSyncPatch(unittest.TestCase):
    """The ``_generate_single_turn`` patch wires sync_weights to the right place.

    These tests exercise the patch-installation branch in isolation. They build
    a stub trainer with just enough attributes to look like
    ``AsyncGRPOTrainer`` for the duration of the relevant code path.

    Background — there are two correct behaviors and we historically had a bug
    where both modes used the same one:

    - Async prefetch ON: the BG generation thread can't safely call
      sync_weights mid-rollout. We no-op the stock hook and drive sync from
      the main thread via ``_maybe_sync_vllm_weights``.
    - Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
      calls ``sync_weights`` once per step boundary on the main thread. We
      wire that hook directly to ``_sync_lora_adapter`` so that stock call
      pushes the LoRA adapter.

    Before the fix, both modes installed ``lambda: None``, so sync mode never
    pushed any LoRA adapter to vLLM and the trainer was a no-op.
    """

    @staticmethod
    def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
        """Return a bare ``AsyncGRPOTrainer`` carrying only the patch's inputs.

        Args:
            vllm_lora_sync: Value stored on ``args.vllm_lora_sync``.
            async_prefetch: Value stored on ``args.async_prefetch``.
        """
        from axolotl.core.trainers.grpo.async_trainer import (
            AsyncGRPOTrainer,
        )

        class FakeArgs:
            pass

        args = FakeArgs()
        args.vllm_lora_sync = vllm_lora_sync
        args.async_prefetch = async_prefetch

        class FakeVllmGen:
            sync_weights = staticmethod(lambda: None)
            model = MagicMock()

        # Use object.__new__ so we don't run __init__ (which needs a real
        # model, dataset, etc.). We only need the `_generate_single_turn`
        # method's patch branch to run, so we set up the minimum state.
        trainer = object.__new__(AsyncGRPOTrainer)
        trainer.args = args
        trainer.use_vllm = True
        trainer.vllm_generation = FakeVllmGen()
        trainer._patched_sync_weights = False
        # Spy on _sync_lora_adapter so we can assert it's the function the
        # hook delegates to in sync mode.
        trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
        trainer._sync_peft_weights_no_merge = MagicMock(
            name="_sync_peft_weights_no_merge_spy"
        )
        return trainer

    @staticmethod
    def _run_patch_branch(trainer):
        """Execute just the sync_weights-patching branch in isolation.

        We can't easily call the real ``_generate_single_turn`` because it
        does a full vLLM generate. Instead we copy the exact branch out of
        the source so the test verifies the same logic the trainer runs.
        """
        if not getattr(trainer, "_patched_sync_weights", False):
            if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
                if getattr(trainer.args, "vllm_lora_sync", False):
                    if getattr(trainer.args, "async_prefetch", False):
                        trainer.vllm_generation.sync_weights = lambda: None
                    else:
                        # Bind the helper now so the hook keeps calling the
                        # same object that was on the trainer at install time.
                        sync_helper = trainer._sync_lora_adapter

                        def _lora_filesystem_sync():
                            sync_helper()

                        trainer.vllm_generation.sync_weights = _lora_filesystem_sync
            # NOTE(review): the nesting level of this flag assignment should
            # mirror the trainer source; every test here enables vLLM and
            # vllm_lora_sync, so they pass at any level inside the outer `if`.
            trainer._patched_sync_weights = True

    def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
        trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
        self._run_patch_branch(trainer)

        self.assertIs(trainer._patched_sync_weights, True)
        # Trigger the patched hook — it must call _sync_lora_adapter.
        trainer.vllm_generation.sync_weights()
        trainer._sync_lora_adapter.assert_called_once()

    def test_async_mode_with_lora_sync_installs_noop_hook(self):
        trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
        self._run_patch_branch(trainer)

        self.assertIs(trainer._patched_sync_weights, True)
        # Hook must be a no-op so BG-thread generation doesn't fight the
        # main-thread optimizer step over the model weights.
        trainer.vllm_generation.sync_weights()
        trainer._sync_lora_adapter.assert_not_called()

    def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
        """Installing the patch should not pre-emptively sync."""
        trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
        self._run_patch_branch(trainer)
        # _sync_lora_adapter should only be called when the patched hook
        # itself is invoked (e.g., from TRL's _generate_single_turn).
        trainer._sync_lora_adapter.assert_not_called()

    def test_patch_is_idempotent(self):
        trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
        self._run_patch_branch(trainer)
        first_hook = trainer.vllm_generation.sync_weights
        # Second call must not re-patch: the hook installed first stays put.
        self._run_patch_branch(trainer)
        self.assertIs(trainer.vllm_generation.sync_weights, first_hook)
|
||||
|
||||
|
||||
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
    """``_maybe_sync_vllm_weights`` must not crash when interval is unset.

    Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
    on the very first call when ``vllm_sync_interval`` was ``None`` (which
    is the default for any config that doesn't explicitly set it). We now
    fall back to interval=1 so unset means "sync every step", matching the
    behavior of TRL's own ``_generate_single_turn``.
    """

    @staticmethod
    def _make_stub_trainer(interval, async_prefetch):
        """Return a bare ``AsyncGRPOTrainer`` at ``global_step == 1``.

        Args:
            interval: Value stored on ``args.vllm_sync_interval``.
            async_prefetch: Value stored on ``args.async_prefetch``.
        """
        from axolotl.core.trainers.grpo.async_trainer import (
            AsyncGRPOTrainer,
        )

        class FakeArgs:
            pass

        class FakeState:
            global_step = 1

        stub_args = FakeArgs()
        for attr_name, attr_value in (
            ("async_prefetch", async_prefetch),
            ("vllm_sync_interval", interval),
            ("vllm_lora_sync", True),
        ):
            setattr(stub_args, attr_name, attr_value)

        # object.__new__ bypasses AsyncGRPOTrainer.__init__ entirely.
        stub = object.__new__(AsyncGRPOTrainer)
        stub.args = stub_args
        stub.use_vllm = True
        stub.state = FakeState()
        stub._last_synced_step = 0
        stub._sync_lora_adapter = MagicMock(name="sync_spy")
        return stub

    @staticmethod
    def _maybe_sync(stub):
        """Call the unbound ``_maybe_sync_vllm_weights`` on the stub trainer."""
        from axolotl.core.trainers.grpo.async_trainer import (
            AsyncGRPOTrainer,
        )

        AsyncGRPOTrainer._maybe_sync_vllm_weights(stub)

    def test_interval_none_in_async_mode_does_not_crash(self):
        stub = self._make_stub_trainer(interval=None, async_prefetch=True)

        # Should not raise TypeError — defaults to every-step sync
        self._maybe_sync(stub)
        stub._sync_lora_adapter.assert_called_once()

    def test_sync_mode_drives_sync(self):
        """Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.

        The previous behavior (early return when ``not async_prefetch``)
        assumed TRL's stock ``_generate_single_turn`` would handle sync.
        That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
        where the data producer bypasses ``_generate_single_turn``
        entirely. Without this trigger no sync ever happens and the
        trainer becomes a no-op.
        """
        stub = self._make_stub_trainer(interval=1, async_prefetch=False)

        self._maybe_sync(stub)
        stub._sync_lora_adapter.assert_called_once()

    def test_async_mode_with_explicit_interval_respects_modulo(self):
        stub = self._make_stub_trainer(interval=4, async_prefetch=True)

        # global_step=1, interval=4 → 1 % 4 != 0 → no sync
        self._maybe_sync(stub)
        stub._sync_lora_adapter.assert_not_called()

        # global_step=4 → 4 % 4 == 0 → sync
        stub.state.global_step = 4
        self._maybe_sync(stub)
        stub._sync_lora_adapter.assert_called_once()
|
||||
|
||||
|
||||
# Allow running this test module directly with the unittest runner.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
|
||||
assert cfg.dataloader_num_workers == 0
|
||||
|
||||
|
||||
class TestSelectWeightSyncTransport(unittest.TestCase):
    """Pure-logic table tests for ``select_weight_sync_transport``."""

    def _caps(self, **kwargs):
        """Return probed capabilities with the given flags set on the instance."""
        from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities

        capabilities = VLLMWeightSyncCapabilities(probed=True)
        for flag_name, flag_value in kwargs.items():
            setattr(capabilities, flag_name, flag_value)
        return capabilities

    def _select(self, *, has_lora, pref, **cap_flags):
        """Run ``select_weight_sync_transport`` for capabilities built from ``cap_flags``."""
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        return select_weight_sync_transport(
            self._caps(**cap_flags), has_lora=has_lora, vllm_lora_sync_pref=pref
        )

    def test_lora_with_native_endpoint(self):
        chosen = self._select(has_lora=True, pref=True, lora_filesystem=True)
        assert chosen == "lora_filesystem"

    def test_lora_with_axolotl_endpoint(self):
        chosen = self._select(has_lora=True, pref=False, lora_axolotl=True)
        assert chosen == "lora_filesystem"

    def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
        chosen = self._select(has_lora=True, pref=False, nccl=True)
        assert chosen == "nccl"

    def test_full_param_prefers_nccl(self):
        chosen = self._select(has_lora=False, pref=False, nccl=True, http_full=True)
        assert chosen == "nccl"

    def test_full_param_falls_back_to_http(self):
        chosen = self._select(has_lora=False, pref=False, http_full=True)
        assert chosen == "http_full"

    def test_full_param_no_routes_returns_none(self):
        # No flags passed: capabilities stay at their defaults.
        chosen = self._select(has_lora=False, pref=False)
        assert chosen == "none"

    def test_lora_no_routes_returns_none(self):
        chosen = self._select(has_lora=True, pref=True)
        assert chosen == "none"
|
||||
|
||||
|
||||
class TestProbeVllmWeightSync(unittest.TestCase):
    """``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""

    @staticmethod
    def _probe_with_paths(paths):
        """Probe ``http://localhost:8000`` with ``requests.get`` mocked.

        The mocked response's ``raise_for_status`` is a no-op and its
        ``json()`` returns ``{"paths": paths}``.
        """
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        openapi_spec = {"paths": paths}
        with patch("requests.get") as mock_get:
            response = mock_get.return_value
            response.raise_for_status = lambda: None
            response.json = lambda: openapi_spec
            return probe_vllm_weight_sync("http://localhost:8000")

    def test_stock_vllm_with_lora_enabled(self):
        """Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
        caps = self._probe_with_paths(
            {
                "/v1/models": {"get": {}},
                "/v1/load_lora_adapter": {"post": {}},
                "/v1/unload_lora_adapter": {"post": {}},
                "/v1/completions": {"post": {}},
            }
        )

        assert caps.probed is True
        assert caps.lora_filesystem is True
        assert caps.lora_axolotl is False
        assert caps.nccl is False
        assert caps.http_full is False

    def test_axolotl_serve_lora_full_capabilities(self):
        """``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
        caps = self._probe_with_paths(
            {
                "/init_communicator/": {"post": {}},
                "/update_named_param/": {"post": {}},
                "/batch_update_named_params/": {"post": {}},
                "/set_lora_adapter/": {"post": {}},
                "/clear_lora_adapter/": {"post": {}},
                "/http_update_weights/": {"post": {}},
                "/v1/load_lora_adapter": {"post": {}},
            }
        )

        assert caps.probed is True
        assert caps.nccl is True
        assert caps.lora_axolotl is True
        assert caps.lora_filesystem is True
        assert caps.http_full is True

    def test_trl_vllm_serve_nccl_only(self):
        """``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
        caps = self._probe_with_paths(
            {
                "/init_communicator/": {"post": {}},
                "/update_named_param/": {"post": {}},
                "/batch_update_named_params/": {"post": {}},
                "/close_communicator/": {"post": {}},
                "/generate/": {"post": {}},
            }
        )

        assert caps.probed is True
        assert caps.nccl is True
        assert caps.lora_filesystem is False
        assert caps.lora_axolotl is False
        assert caps.http_full is False

    def test_unreachable_server_records_error(self):
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        # The HTTP call itself raises, so no response object is configured.
        with patch("requests.get") as mock_get:
            mock_get.side_effect = ConnectionError("Connection refused")
            caps = probe_vllm_weight_sync("http://localhost:9999")

        assert caps.probed is False
        assert caps.probe_error is not None
        assert "ConnectionError" in caps.probe_error
        assert caps.nccl is False
        assert caps.lora_filesystem is False
|
||||
|
||||
|
||||
class TestPluginWeightSyncEnforcement(unittest.TestCase):
    """End-to-end test of post_trainer_create's transport-selection branch.

    The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
    leaving the trainer learning in isolation while vLLM kept serving the
    unmodified base model. After the fix:

    - LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
    - LoRA + only NCCL endpoint → uses NCCL broadcast
    - Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
    - Full FT + HTTP endpoint → raises NotImplementedError (step 3)
    - No usable transport → raises ValueError with a precise diagnosis
    """

    @staticmethod
    def _fake_cfg(adapter, vllm_lora_sync):
        """Return a minimal config object with a nested ``trl`` section."""

        class FakeTRL:
            pass

        class FakeCfg:
            pass

        trl_section = FakeTRL()
        vars(trl_section).update(
            vllm_lora_sync=vllm_lora_sync,
            vllm_server_host="127.0.0.1",
            vllm_server_port=8000,
        )

        config = FakeCfg()
        vars(config).update(
            nemo_gym_enabled=True,
            nemo_gym_model_name=None,
            base_model="test/model",
            nemo_gym_verify_timeout=30,
            nemo_gym_multi_turn=True,
            adapter=adapter,
            trl=trl_section,
        )
        return config

    @staticmethod
    def _fake_trainer():
        """Return a trainer stub exposing only ``vllm_generation.sync_weights``."""

        class FakeVLLMGen:
            sync_weights = staticmethod(lambda: None)

        class FakeTrainer:
            vllm_generation = FakeVLLMGen()

        return FakeTrainer()

    @staticmethod
    def _caps(**kwargs):
        """Return probed capabilities with the given flags set on the instance."""
        from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities

        capabilities = VLLMWeightSyncCapabilities(probed=True)
        for flag_name, flag_value in kwargs.items():
            setattr(capabilities, flag_name, flag_value)
        return capabilities

    def _scenario(self, caps, *, adapter, vllm_lora_sync):
        """Build ``(plugin, cfg, trainer)``; ``caps=None`` leaves ``_vllm_caps`` untouched."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        if caps is not None:
            plugin._vllm_caps = caps
        cfg = self._fake_cfg(adapter=adapter, vllm_lora_sync=vllm_lora_sync)
        return plugin, cfg, self._fake_trainer()

    def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
        from unittest.mock import patch

        plugin, cfg, trainer = self._scenario(
            self._caps(lora_filesystem=True), adapter="lora", vllm_lora_sync=True
        )

        with (
            patch.object(plugin, "_setup_lora_sync") as setup,
            patch.object(plugin, "_check_lora_endpoint") as check,
            patch.object(plugin, "_wire_multi_turn") as wire,
        ):
            plugin.post_trainer_create(cfg, trainer)

        for spy in (setup, check, wire):
            spy.assert_called_once()

    def test_lora_with_no_routes_raises_with_lora_specific_message(self):
        # All capability flags left at their defaults, but probed=True.
        plugin, cfg, trainer = self._scenario(
            self._caps(), adapter="lora", vllm_lora_sync=False
        )

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        msg = str(ctx.exception)
        assert "no-op trainer" in msg
        assert "load_lora_adapter" in msg
        assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg

    def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
        from unittest.mock import patch

        plugin, cfg, trainer = self._scenario(
            self._caps(nccl=True), adapter=None, vllm_lora_sync=False
        )

        with patch.object(plugin, "_wire_multi_turn") as wire:
            plugin.post_trainer_create(cfg, trainer)

        wire.assert_called_once()

    def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
        plugin, cfg, trainer = self._scenario(
            self._caps(http_full=True), adapter=None, vllm_lora_sync=False
        )

        with self.assertRaises(NotImplementedError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        assert "HTTP weight sync" in str(ctx.exception)

    def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
        plugin, cfg, trainer = self._scenario(
            self._caps(), adapter=None, vllm_lora_sync=False
        )

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        msg = str(ctx.exception)
        assert "no-op trainer" in msg
        assert "init_communicator" in msg
        assert "http_update_weights" in msg

    def test_unprobed_caps_raises_with_probe_failure_message(self):
        # Plugin._vllm_caps left as default-None: the post_trainer_create
        # branch falls back to a fresh VLLMWeightSyncCapabilities() with
        # probed=False, so the error path should mention probing.
        plugin, cfg, trainer = self._scenario(
            None, adapter="lora", vllm_lora_sync=True
        )

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        assert "could not probe" in str(ctx.exception)
|
||||
|
||||
|
||||
class TestNemoGymE2E(unittest.TestCase):
|
||||
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
||||
|
||||
@@ -452,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
|
||||
trainer = self._make_mock_trainer()
|
||||
producer._trainer = trainer
|
||||
|
||||
# Mock the prompt iterator (returns a batch of 1 input)
|
||||
producer._prompt_iter = iter(
|
||||
[
|
||||
[
|
||||
{
|
||||
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
||||
}
|
||||
]
|
||||
]
|
||||
)
|
||||
producer._prompt_dl = [
|
||||
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
||||
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
|
||||
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
|
||||
# copies of each unique prompt — one entry per rollout.
|
||||
_prompt_batch = [
|
||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||
]
|
||||
producer._prompt_iter = iter([_prompt_batch])
|
||||
producer._prompt_dl = [_prompt_batch]
|
||||
|
||||
# Call produce
|
||||
result = producer.produce(model=MagicMock(), global_step=1)
|
||||
@@ -530,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
|
||||
producer._request_timeout = 30
|
||||
producer._num_generations = 2
|
||||
producer._trainer = self._make_mock_trainer()
|
||||
producer._prompt_iter = iter(
|
||||
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||
)
|
||||
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||
# RepeatSampler pre-expands by num_generations=2.
|
||||
_prompt_batch = [
|
||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||
]
|
||||
producer._prompt_iter = iter([_prompt_batch])
|
||||
producer._prompt_dl = [_prompt_batch]
|
||||
|
||||
result = producer.produce(model=MagicMock(), global_step=1)
|
||||
|
||||
|
||||
@@ -38,6 +38,30 @@ def _reference_norm_noscale(x, eps):
|
||||
return norm(x)
|
||||
|
||||
|
||||
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
    """Reference partial-rotary RMSNorm + RoPE built from stock primitives.

    Applies ``Gemma4RMSNorm`` (with ``weight`` copied in) over the full last
    dimension of ``x``, then stock ``apply_rotary_pos_emb`` over the first
    ``cos.shape[-1]`` columns; any trailing columns are passed through
    unchanged. When ``cos.shape[-1]`` equals ``x.shape[-1]`` the whole
    normalized tensor is rotated.
    """
    from transformers.models.gemma4.modeling_gemma4 import (
        Gemma4RMSNorm,
        apply_rotary_pos_emb,
    )

    head_dim = x.shape[-1]
    rotary_dim = cos.shape[-1]

    rms_norm = Gemma4RMSNorm(head_dim, eps=eps).to(x.device, x.dtype)
    rms_norm.weight.data.copy_(weight)
    normed = rms_norm(x)

    if rotary_dim != head_dim:
        # Rotate only the leading slice and re-attach the untouched tail.
        head = normed[..., :rotary_dim]
        tail = normed[..., rotary_dim:]
        rotated_head = apply_rotary_pos_emb(head, cos, sin, unsqueeze_dim=2)
        return torch.cat([rotated_head, tail], dim=-1)

    return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
(2, 64, 32, 256), # sliding window layer shape
|
||||
@@ -194,6 +218,172 @@ class TestFusedRMSNormRoPEBackward:
|
||||
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
||||
|
||||
|
||||
class TestFusedRMSNormRoPEPartialRotary:
    """Partial-rotary: cos/sin last dim is smaller than head_dim.

    Compares against the original primitives (`Gemma4RMSNorm` +
    `apply_rotary_pos_emb`) applied to the rotated slice with the trailing
    columns passed through. Without the kernel fix this used to crash with
    `RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
    """

    @pytest.mark.parametrize(
        "B,S,H,D,n_rot",
        [
            (2, 16, 4, 64, 32),  # half rotary (Llama-style 0.5)
            (2, 16, 4, 64, 16),  # quarter rotary
            (2, 32, 8, 128, 64),  # half rotary, larger heads
            (1, 8, 2, 256, 64),  # 26B sliding-shape, 0.25 partial
            (1, 8, 2, 64, 64),  # n_rot == D: must still match full-rotary path
        ],
        ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
    )
    def test_forward_matches_reference(self, B, S, H, D, n_rot):
        """Fused forward output matches the reference built from stock primitives."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)

        y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)

        assert y_fused.shape == y_ref.shape == (B, S, H, D)
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, (
            f"partial rotary forward cosine_sim={cos_sim:.6f} "
            f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
        )

        # The pass-through tail must match the reference RMSNorm output to
        # within the tolerance below (rtol/atol = 1e-2 on bf16 tensors). A
        # larger deviation would mean the kernel is touching it with a
        # spurious rotation, which is the original bug class.
        torch.testing.assert_close(
            y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
        )

    @pytest.mark.parametrize(
        "B,S,H,D,n_rot",
        [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
        ids=["half_64", "quarter_256"],
    )
    def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
        """Gradient w.r.t. ``x`` from the fused op matches the reference."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

        # Reference backward via the original primitives
        x_ref = x_data.clone().requires_grad_(True)
        w_ref = weight_init.clone()
        y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
        y_ref.sum().backward()

        # Fused backward
        x_fused = x_data.clone().requires_grad_(True)
        w_fused = weight_init.clone().requires_grad_(True)
        y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
        y_fused.sum().backward()

        cos_sim_x = torch.nn.functional.cosine_similarity(
            x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
        )
        assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"

    @pytest.mark.parametrize(
        "B,S,H,D,n_rot",
        [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
        ids=["half_64", "quarter_256"],
    )
    def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
        """Gradient w.r.t. the norm weight from the fused op matches the reference."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

        # Reference: Gemma4RMSNorm whose .weight collects grads, then partial
        # rotary applied to the rotated slice.
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight = torch.nn.Parameter(weight_init.clone())
        normed = norm_ref(x_data)

        rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
        y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
        y_ref.sum().backward()

        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()

        cos_sim_w = torch.nn.functional.cosine_similarity(
            w_fused.grad.flatten().float(),
            norm_ref.weight.grad.flatten().float(),
            dim=0,
        )
        assert cos_sim_w > 0.995, (
            f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
        )

    def test_full_rotary_unchanged_when_n_rot_equals_d(self):
        """Regression: passing cos/sin with shape == head_dim must still
        match the full-rotary reference (the partial-rotary code path must
        not perturb the existing full-rotary output)."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 2, 16, 4, 64
        eps = 1e-6
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

        y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"

    def test_validation_errors(self):
        """Wrapper rejects misshaped inputs cleanly (instead of a cryptic
        Triton crash deeper in the kernel)."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 1, 4, 2, 64
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        w = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # n_rot > head_dim
        cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
        sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(ValueError, match="cannot exceed head_dim"):
            fused_rms_norm_rope(x, w, cos_big, sin_big)

        # cos/sin last-dim mismatch
        cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(ValueError, match="same last dim"):
            fused_rms_norm_rope(x, w, cos, sin)

        # odd rotary dim
        cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
        sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(ValueError, match="must be even"):
            fused_rms_norm_rope(x, w, cos_odd, sin_odd)
|
||||
|
||||
|
||||
class TestFusedRMSNormNoScale:
|
||||
"""Tests for v_norm (RMSNorm without learnable scale)."""
|
||||
|
||||
|
||||
219
tests/monkeypatch/test_gemma4_fused_attn.py
Normal file
219
tests/monkeypatch/test_gemma4_fused_attn.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Tests for the Gemma 4 fused-attention monkey-patch.
|
||||
|
||||
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
||||
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
||||
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
||||
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
|
||||
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
|
||||
exercised end-to-end (this is the bug originally documented in
|
||||
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
||||
|
||||
The full-model forward also pins that the fused forward keeps accepting
|
||||
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
|
||||
installed transformers version — so any future signature drift on
|
||||
upstream's side trips a clear failure here instead of a confusing
|
||||
TypeError deep in a training run.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Module-wide marker: every test here is skipped when torch reports no CUDA.
pytestmark = [
    pytest.mark.skipif(
        not torch.cuda.is_available(),
        reason="CUDA required",
    ),
]
|
||||
|
||||
pytest.importorskip(
|
||||
"transformers.models.gemma4",
|
||||
reason="fused_attn patch only matters when Gemma 4 is available",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def restore_gemma4_attention():
    """Yield ``Gemma4TextAttention`` and put its ``forward`` back afterwards.

    The attribute is captured before the test body runs and reassigned during
    teardown, so a monkey-patch installed by one test is not visible to the
    next one.
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention

    original_forward = Gemma4TextAttention.forward
    yield Gemma4TextAttention
    # Teardown: undo whatever the test installed on the class.
    setattr(Gemma4TextAttention, "forward", original_forward)
|
||||
|
||||
|
||||
def _build_hybrid_config():
    """Return a small two-layer ``Gemma4TextConfig`` mixing both layer types.

    Layer 0 is ``sliding_attention`` (``head_dim=32``, default rope) and
    layer 1 is ``full_attention`` (``global_head_dim=64``, proportional rope
    with ``partial_rotary_factor=0.25``). The attention implementation is
    set to ``"sdpa"`` on the returned config.
    """
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig

    # Per-layer-type rope settings; only the full-attention entry is partial.
    rope_by_layer_type = {
        "sliding_attention": {
            "rope_type": "default",
            "rope_theta": 10000.0,
        },
        "full_attention": {
            "rope_type": "proportional",
            "rope_theta": 1000000.0,
            "partial_rotary_factor": 0.25,
        },
    }
    config_kwargs = {
        "vocab_size": 128,
        "hidden_size": 64,
        "intermediate_size": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 32,
        "global_head_dim": 64,
        "layer_types": ["sliding_attention", "full_attention"],
        "sliding_window": 64,
        "max_position_embeddings": 2048,
        "hidden_size_per_layer_input": 16,
        "vocab_size_per_layer_input": 128,
        "rope_parameters": rope_by_layer_type,
    }

    hybrid_cfg = Gemma4TextConfig(**config_kwargs)
    hybrid_cfg._attn_implementation = "sdpa"
    return hybrid_cfg
|
||||
|
||||
|
||||
def _build_model(seed=0):
    """Build a seeded ``Gemma4TextModel`` on CUDA in bf16, in eval mode.

    Args:
        seed: value passed to ``torch.manual_seed`` before the model is
            constructed from ``_build_hybrid_config()``.
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel

    torch.manual_seed(seed)
    hybrid_cfg = _build_hybrid_config()
    model = Gemma4TextModel(hybrid_cfg)
    model = model.cuda()
    model = model.to(torch.bfloat16)
    return model.eval()
|
||||
|
||||
|
||||
class TestFusedAttnSignature:
    """Call-shape compatibility between the decoder layer and the fused forward.

    If the arguments the decoder layer passes to attention stop matching what
    the patched forward accepts, the model forward below fails.
    """

    def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
        """Run a whole-model forward with the fused patch installed and check
        the output shape and finiteness."""
        from axolotl.monkeypatch.models.gemma4.fused_attn import (
            patch_gemma4_fused_attn,
        )

        batch, seq = 2, 16
        model = _build_model()
        input_ids = torch.randint(0, 128, (batch, seq), device="cuda")
        attention_mask = torch.ones(batch, seq, dtype=torch.long, device="cuda")

        patch_gemma4_fused_attn()
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        assert hidden.shape == (batch, seq, 64)
        assert torch.isfinite(hidden).all()
|
||||
|
||||
|
||||
class TestFusedAttnPerLayerCorrectness:
    """Single-layer comparison of the patched attention against the stock one.

    Working on one attention call at a time keeps drift from other layers
    out of the comparison.
    """

    def _run_attention(self, model, layer_idx, hidden_states, position_ids):
        """Invoke layer ``layer_idx``'s ``self_attn`` once and return its
        first output.

        The rotary embeddings are computed with ``model.rotary_emb`` for that
        layer's type; no attention mask is passed and ``shared_kv_states`` is
        an empty dict.
        """
        attention = model.layers[layer_idx].self_attn
        kind = model.config.layer_types[layer_idx]
        cos_emb, sin_emb = model.rotary_emb(hidden_states, position_ids, kind)
        attn_out, _ = attention(
            hidden_states=hidden_states,
            position_embeddings=(cos_emb, sin_emb),
            attention_mask=None,
            shared_kv_states={},
        )
        return attn_out

    @pytest.mark.parametrize(
        "layer_idx",
        [0, 1],
        ids=["sliding_head32", "global_head64_proportional"],
    )
    def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
        """Outputs before and after installing the patch must agree for the
        given layer."""
        from axolotl.monkeypatch.models.gemma4.fused_attn import (
            patch_gemma4_fused_attn,
        )

        model = _build_model(seed=1)
        hidden = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
        positions = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)

        # Reference run with whatever forward is installed before the patch.
        with torch.no_grad():
            expected = self._run_attention(model, layer_idx, hidden, positions)

        # Same inputs again, now through the patched forward.
        patch_gemma4_fused_attn()
        with torch.no_grad():
            actual = self._run_attention(model, layer_idx, hidden, positions)

        assert actual.shape == expected.shape
        assert torch.isfinite(actual).all()

        similarity = torch.nn.functional.cosine_similarity(
            expected.flatten().float(), actual.flatten().float(), dim=0
        )
        assert similarity > 0.999, (
            f"layer {layer_idx} fused vs stock cosine_sim={similarity:.6f}"
        )
        # Element-wise check on bf16 outputs with rtol = atol = 5e-2.
        torch.testing.assert_close(actual, expected, rtol=5e-2, atol=5e-2)
|
||||
|
||||
|
||||
class TestFusedAttnFullModel:
    """Whole-model forward and backward checks covering both layer types."""

    def test_full_forward_matches_stock(self, restore_gemma4_attention):
        """The final hidden states with and without the patch must be nearly
        collinear (cosine similarity above 0.999)."""
        from axolotl.monkeypatch.models.gemma4.fused_attn import (
            patch_gemma4_fused_attn,
        )

        model = _build_model(seed=2)
        input_ids = torch.randint(0, 128, (2, 32), device="cuda")
        attention_mask = torch.ones(2, 32, dtype=torch.long, device="cuda")

        def forward_once():
            # Clone so the result does not alias storage reused by a later run.
            with torch.no_grad():
                result = model(input_ids=input_ids, attention_mask=attention_mask)
                return result.last_hidden_state.clone()

        expected = forward_once()
        patch_gemma4_fused_attn()
        actual = forward_once()

        assert actual.shape == expected.shape
        assert torch.isfinite(actual).all()
        # Only the direction of the flattened outputs is compared; bf16 drift
        # across the two layers is tolerated.
        similarity = torch.nn.functional.cosine_similarity(
            expected.flatten().float(), actual.flatten().float(), dim=0
        )
        assert similarity > 0.999, f"end-to-end cosine_sim={similarity:.6f}"

    def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
        """After a backward pass through the patched model, ``q_norm.weight``
        and ``k_norm.weight`` of the first two layers must hold finite,
        non-zero gradients."""
        from axolotl.monkeypatch.models.gemma4.fused_attn import (
            patch_gemma4_fused_attn,
        )

        model = _build_model(seed=3).train()
        patch_gemma4_fused_attn()

        input_ids = torch.randint(0, 128, (2, 16), device="cuda")
        attention_mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
        hidden = model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        hidden.sum().backward()

        for idx, decoder_layer in enumerate(model.layers[:2]):
            attention = decoder_layer.self_attn
            q_grad = attention.q_norm.weight.grad
            k_grad = attention.k_norm.weight.grad
            assert q_grad is not None, f"layer {idx} q_norm no grad"
            assert k_grad is not None, f"layer {idx} k_norm no grad"
            assert q_grad.isfinite().all()
            assert k_grad.isfinite().all()
            assert q_grad.abs().sum() > 0
            assert k_grad.abs().sum() > 0
|
||||
343
tests/monkeypatch/test_gemma4_hybrid_mask.py
Normal file
343
tests/monkeypatch/test_gemma4_hybrid_mask.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Tests for the Gemma 4 hybrid-attention mask fix.
|
||||
|
||||
These tests pin the single critical behavior: after installing the patch,
|
||||
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
|
||||
the underlying mask builder regardless of what the caller's config says.
|
||||
This is what keeps full-attention (head_dim=512) global layers from
|
||||
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
|
||||
the 2D FA2 mask that would be built from the model-level config.
|
||||
|
||||
The tests use a mocked ``create_causal_mask`` so they don't have to load
|
||||
a real 26B Gemma 4 model or even have access to its weights. What matters
|
||||
for the bug fix is which config is handed to the mask factory, not the
|
||||
factory's actual output.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip(
|
||||
"transformers.models.gemma4",
|
||||
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def restore_gemma4_module():
    """Yield ``modeling_gemma4`` and undo mask patching once the test ends.

    Teardown reassigns the ``create_causal_mask`` captured at setup and
    clears ``gemma4_hybrid_mask._PATCH_APPLIED`` so a later test can install
    the patch again.
    """
    from transformers.models.gemma4 import modeling_gemma4

    original_factory = modeling_gemma4.create_causal_mask
    yield modeling_gemma4

    setattr(modeling_gemma4, "create_causal_mask", original_factory)
    # Imported only at teardown, exactly when the flag has to be reset.
    from axolotl.monkeypatch import gemma4_hybrid_mask

    setattr(gemma4_hybrid_mask, "_PATCH_APPLIED", False)
|
||||
|
||||
|
||||
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
    """Installing the patch returns ``True``, swaps ``create_causal_mask`` and
    keeps the previous callable on the wrapper's ``_axolotl_original``."""
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    module = restore_gemma4_module
    before = module.create_causal_mask

    installed = patch_gemma4_hybrid_mask()
    assert installed is True

    assert module.create_causal_mask is not before
    assert module.create_causal_mask._axolotl_original is before, (
        "patched wrapper must expose the original reference for teardown"
    )
|
||||
|
||||
|
||||
def test_patch_is_idempotent(restore_gemma4_module):
    """Calling the installer twice must leave the very same wrapper object in
    place instead of wrapping the wrapper."""
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    module = restore_gemma4_module

    patch_gemma4_hybrid_mask()
    after_first_call = module.create_causal_mask

    patch_gemma4_hybrid_mask()
    after_second_call = module.create_causal_mask

    assert after_first_call is after_second_call
|
||||
|
||||
|
||||
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
    """The wrapper hands the wrapped factory an ``"sdpa"`` config.

    A mock stands in for the original ``create_causal_mask``. The wrapper is
    called with a config whose ``_attn_implementation`` is
    ``"flash_attention_2"``; the mock must receive a config reporting
    ``"sdpa"`` with the other attributes intact, while the caller's object
    stays unmodified.
    """
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    module = restore_gemma4_module

    # The mock is installed first so that the patch wraps it as the original.
    recorder = MagicMock(name="create_causal_mask", return_value="mask_4d")
    module.create_causal_mask = recorder
    patch_gemma4_hybrid_mask()

    fa2_config = SimpleNamespace(
        _attn_implementation="flash_attention_2",
        head_dim=512,
        some_other_attr="preserved",
    )
    returned = module.create_causal_mask(
        fa2_config,
        inputs_embeds=None,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
    )

    # The wrapper's return value is exactly what the wrapped factory produced.
    assert returned == "mask_4d"

    # One call, and its first positional argument reports sdpa.
    assert recorder.call_count == 1
    positional, _ = recorder.call_args
    forwarded_config = positional[0]
    assert forwarded_config._attn_implementation == "sdpa"

    # The object supplied by the caller still reports fa2.
    assert fa2_config._attn_implementation == "flash_attention_2"

    # Unrelated attributes are carried over onto the forwarded config.
    assert forwarded_config.head_dim == 512
    assert forwarded_config.some_other_attr == "preserved"
|
||||
|
||||
|
||||
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
    """Extra positional arguments and every keyword argument must reach the
    wrapped factory with their values unchanged."""
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    module = restore_gemma4_module

    recorder = MagicMock(return_value="mask")
    module.create_causal_mask = recorder
    patch_gemma4_hybrid_mask()

    sent_kwargs = {
        "inputs_embeds": "embeds",
        "attention_mask": "mask_2d",
        "past_key_values": "cache",
        "position_ids": "positions",
        "or_mask_function": "or_fn",
    }
    fa2_config = SimpleNamespace(_attn_implementation="flash_attention_2")
    module.create_causal_mask(fa2_config, "positional_arg", **sent_kwargs)

    received_args, received_kwargs = recorder.call_args
    # Slot 0 holds the config; the next positional must be untouched.
    assert received_args[1] == "positional_arg"
    # Every keyword argument arrives with the value that was sent.
    for key, value in sent_kwargs.items():
        assert received_kwargs[key] == value
|
||||
|
||||
|
||||
def test_unpatch_restores_original(restore_gemma4_module):
    """Patching and then unpatching must leave the exact object that was
    installed beforehand on ``create_causal_mask``."""
    from axolotl.monkeypatch.gemma4_hybrid_mask import (
        patch_gemma4_hybrid_mask,
        unpatch_gemma4_hybrid_mask,
    )

    module = restore_gemma4_module
    placeholder = MagicMock(name="original")
    module.create_causal_mask = placeholder

    patch_gemma4_hybrid_mask()
    assert module.create_causal_mask is not placeholder

    unpatch_gemma4_hybrid_mask()
    assert module.create_causal_mask is placeholder
|
||||
|
||||
|
||||
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
    """Unpatching when this test installed nothing must simply return; the
    test passes as long as no exception escapes."""
    from axolotl.monkeypatch import gemma4_hybrid_mask as hybrid_mask

    hybrid_mask.unpatch_gemma4_hybrid_mask()
|
||||
|
||||
|
||||
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
    """``create_sliding_window_causal_mask`` must be the same object before
    and after the patch is installed.

    The test is skipped when the installed transformers module does not
    define that attribute.
    """
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    module = restore_gemma4_module
    attr = "create_sliding_window_causal_mask"
    if not hasattr(module, attr):
        pytest.skip("transformers version has no create_sliding_window_causal_mask")

    before = getattr(module, attr)
    patch_gemma4_hybrid_mask()
    after = getattr(module, attr)
    assert after is before
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
|
||||
#
|
||||
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
|
||||
# model with 2 layers (one sliding, one full_attention), apply the hybrid
|
||||
# attention path end-to-end, and run a forward pass with a padded
|
||||
# attention_mask at a long-ish seq len. The invariant we're pinning is that
|
||||
# the full_attention layer does not crash with the
|
||||
# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
|
||||
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
|
||||
# tokens in the FSDP2 training run.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_tiny_gemma4_text_model():
    """Build a two-layer ``Gemma4TextModel`` (sliding + full attention).

    Returns:
        ``(model, cfg)`` where the model is in eval mode, constructed after
        ``torch.manual_seed(42)``, and ``cfg._attn_implementation`` has been
        set to ``"sdpa"``.
    """
    import torch
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel

    config_kwargs = {
        "vocab_size": 128,
        "hidden_size": 64,
        "intermediate_size": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 32,
        "layer_types": ["sliding_attention", "full_attention"],
        "sliding_window": 64,
        "max_position_embeddings": 2048,
        "hidden_size_per_layer_input": 16,
        "vocab_size_per_layer_input": 128,
    }
    tiny_cfg = Gemma4TextConfig(**config_kwargs)
    # The model-level attention implementation is sdpa for this helper.
    tiny_cfg._attn_implementation = "sdpa"

    torch.manual_seed(42)
    tiny_model = Gemma4TextModel(tiny_cfg).eval()
    return tiny_model, tiny_cfg
|
||||
|
||||
|
||||
def _apply_hybrid_attn_inline(model, cfg):
    """Switch every non-sliding layer's attention config to SDPA, then install
    the hybrid mask patch.

    For each layer whose entry in ``cfg.layer_types`` is not
    ``"sliding_attention"``, the layer's ``self_attn.config`` (when present)
    is replaced by a shallow copy whose ``_attn_implementation`` is
    ``"sdpa"``. Presumably this mirrors
    ``patch_manager._apply_gemma_hybrid_attention`` — that function is not
    visible here, so confirm against it.
    """
    import copy

    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    for idx, decoder_layer in enumerate(model.layers):
        if cfg.layer_types[idx] == "sliding_attention":
            continue
        attention = getattr(decoder_layer, "self_attn", None)
        if attention is None or not hasattr(attention, "config"):
            continue
        # Copy first so the config object shared elsewhere is not mutated.
        per_layer_cfg = copy.copy(attention.config)
        per_layer_cfg._attn_implementation = "sdpa"
        attention.config = per_layer_cfg

    patch_gemma4_hybrid_mask()
|
||||
|
||||
|
||||
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
    """A 1024-token forward with a partly padded mask must complete.

    The tiny model is built on the default device, the hybrid attention
    setup is applied, and the second row of the attention mask has its
    second half zeroed. The output must have shape ``(B, S, hidden_size)``
    and contain only finite values.
    """
    import torch

    tiny_model, tiny_cfg = _build_tiny_gemma4_text_model()
    _apply_hybrid_attn_inline(tiny_model, tiny_cfg)

    batch, seq_len = 2, 1024
    token_ids = torch.randint(0, tiny_cfg.vocab_size, (batch, seq_len))
    padding_mask = torch.ones(batch, seq_len, dtype=torch.long)
    # Zero the tail of row 1 so the mask is not all ones; per the original
    # author's note, an all-ones mask would not materialize a 4D mask.
    half = seq_len // 2
    padding_mask[1, half:] = 0

    with torch.no_grad():
        outputs = tiny_model(input_ids=token_ids, attention_mask=padding_mask)

    hidden = outputs.last_hidden_state
    assert hidden.shape == (batch, seq_len, tiny_cfg.hidden_size)
    assert torch.isfinite(hidden).all()
|
||||
|
||||
|
||||
def test_patched_create_causal_mask_returns_4d_for_real_config(
    restore_gemma4_module,
):
    """The patched wrapper around the real mask factory yields a 4D tensor.

    No mock is involved: a real ``Gemma4TextConfig`` whose
    ``_attn_implementation`` is ``"flash_attention_2"`` is passed to the
    patched ``create_causal_mask`` with a padded attention mask. The result
    must be a 4D tensor of batch size ``B`` whose last two dims are ``S``,
    and the config must still report ``"flash_attention_2"`` afterwards.
    """
    import torch
    from transformers.cache_utils import DynamicCache
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig

    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    patch_gemma4_hybrid_mask()
    module = restore_gemma4_module

    config_kwargs = {
        "vocab_size": 128,
        "hidden_size": 64,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 32,
        "layer_types": ["sliding_attention", "full_attention"],
        "sliding_window": 64,
        "max_position_embeddings": 2048,
        "hidden_size_per_layer_input": 16,
        "vocab_size_per_layer_input": 128,
    }
    fa2_cfg = Gemma4TextConfig(**config_kwargs)
    # The caller-level setting is FA2; the wrapper is expected to override it
    # only for the mask construction.
    fa2_cfg._attn_implementation = "flash_attention_2"

    batch, seq_len = 2, 1024
    embeds = torch.zeros((batch, seq_len, fa2_cfg.hidden_size), dtype=torch.float32)
    padding_mask = torch.ones((batch, seq_len), dtype=torch.long)
    padding_mask[1, seq_len // 2 :] = 0  # padded row, so the mask is not all ones
    positions = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
    cache = DynamicCache(config=fa2_cfg)

    mask = module.create_causal_mask(
        config=fa2_cfg,
        inputs_embeds=embeds,
        attention_mask=padding_mask,
        past_key_values=cache,
        position_ids=positions,
    )

    assert mask is not None
    assert isinstance(mask, torch.Tensor)
    assert mask.dim() == 4, (
        f"expected a 4D SDPA-format mask, got {mask.dim()}D "
        f"shape={tuple(mask.shape)}. The full_attention global layers need "
        "this shape or they crash at long context."
    )
    assert mask.shape[0] == batch
    assert mask.shape[-1] == seq_len
    assert mask.shape[-2] == seq_len

    # The config passed in must not have been modified by the wrapper.
    assert fa2_cfg._attn_implementation == "flash_attention_2"
|
||||
@@ -5,6 +5,8 @@ Covers:
|
||||
- save_strategy: 'best' requires metric_for_best_model
|
||||
- streaming=True with val_set_size > 0 is rejected
|
||||
- lora_target_modules with invalid regex patterns is rejected
|
||||
- GRPO: generation batch size must be divisible by num_generations,
|
||||
num_generations >= 2, and effective_gbs >= num_generations * world_size
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
|
||||
)
|
||||
with pytest.raises(ValueError, match="invalid regex pattern"):
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestGRPOBatchSizeValidator:
    """Checks for ``RLValidationMixin.check_grpo_batch_size_divisibility``.

    Every test passes a plain dict straight to the validator, so no other
    configuration fields (datasets, model loading, ...) are involved. The
    cases cover divisibility by ``num_generations``, the
    ``num_generations >= 2`` requirement, an explicit
    ``trl.generation_batch_size``, the ``world_size`` group-size rule, and
    the inputs for which the validator must return without raising.
    """

    @staticmethod
    def _check(data):
        """Run the validator on ``data`` and return its result."""
        from axolotl.utils.schemas.validation import RLValidationMixin

        validator = RLValidationMixin.check_grpo_batch_size_divisibility
        return validator(data)

    @staticmethod
    def _config(grad_accum, trl, rl="grpo", **extra):
        """Build a config dict with ``micro_batch_size=1``.

        Keys are inserted in the order rl, micro_batch_size,
        gradient_accumulation_steps, any ``extra`` entries, trl.
        """
        return {
            "rl": rl,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": grad_accum,
            **extra,
            "trl": trl,
        }

    def test_divisible_passes(self):
        # 1 * 4 is divisible by num_generations=4, so nothing is raised.
        result = self._check(self._config(4, {"num_generations": 4}))
        assert result["trl"]["num_generations"] == 4

    def test_non_divisible_raises(self):
        cfg = self._config(2, {"num_generations": 4})
        with pytest.raises(ValueError, match="num_generations"):
            self._check(cfg)

    def test_non_divisible_error_includes_fix_hint(self):
        cfg = self._config(3, {"num_generations": 4})
        with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
            self._check(cfg)

    def test_num_generations_one_raises(self):
        cfg = self._config(4, {"num_generations": 1})
        with pytest.raises(ValueError, match=r"num_generations >= 2"):
            self._check(cfg)

    def test_explicit_generation_batch_size_divisible_passes(self):
        trl = {"num_generations": 4, "generation_batch_size": 8}
        result = self._check(self._config(1, trl))
        assert result["trl"]["generation_batch_size"] == 8

    def test_explicit_generation_batch_size_non_divisible_raises(self):
        trl = {"num_generations": 4, "generation_batch_size": 6}
        cfg = self._config(1, trl)
        with pytest.raises(ValueError, match="trl.generation_batch_size"):
            self._check(cfg)

    def test_non_grpo_skips_check(self):
        # With rl set to something other than grpo the very same dict object
        # must come back, even though 3 is not divisible by 4.
        cfg = self._config(3, {"num_generations": 4}, rl="dpo")
        assert self._check(cfg) is cfg

    def test_no_rl_set_skips_check(self):
        cfg = {
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 3,
        }
        assert self._check(cfg) is cfg

    def test_grpo_without_num_generations_skips_check(self):
        # An empty trl section carries no num_generations; the validator is
        # expected to return instead of raising.
        result = self._check(self._config(3, {}))
        assert result["rl"] == "grpo"

    def test_multi_rank_group_size_check(self):
        # Effective batch of 4 with world_size=2 and num_generations=4 is
        # below 4 * 2 = 8, so the validator must raise.
        cfg = self._config(4, {"num_generations": 4}, world_size=2)
        with pytest.raises(ValueError, match=r"world_size=2"):
            self._check(cfg)

    def test_multi_rank_group_size_satisfied(self):
        # Effective batch of 8 meets the 4 * 2 requirement.
        cfg = self._config(8, {"num_generations": 4}, world_size=2)
        result = self._check(cfg)
        assert result["gradient_accumulation_steps"] == 8
|
||||
|
||||
Reference in New Issue
Block a user