Ensure device mesh patching is applied (#2842)

* move patches; make patch stronger

* fix broken tests

* guard sequence_parallel_degree comparison against None

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Author:    Dan Saunders
Date:      2025-06-29 22:16:32 -04:00
Committed: GitHub
Parent:    cb811f8bf1
Commit:    35fdbce102

4 changed files with 52 additions and 19 deletions


@@ -65,6 +65,7 @@ class PatchManager:
         self._apply_mistral_cross_entropy_patch()
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
+        self._apply_sequence_parallel_patches()

     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -231,6 +232,17 @@ class PatchManager:
         patch_gemma3_conditional_generation_forward()

+    def _apply_sequence_parallel_patches(self):
+        """Apply sequence parallelism patches."""
+        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
+            from axolotl.monkeypatch.ring_attn.patch import (
+                patch_prepare_data_loader,
+                patch_prepare_device_mesh,
+            )
+
+            patch_prepare_data_loader()
+            patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
+
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
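
The truthiness check above is the "guard sequence_parallel_degree comparison against None" fix from the commit message: in Python 3, an ordering comparison between `None` and an `int` raises `TypeError`, so checking truthiness first short-circuits before the `> 1` comparison can blow up. A minimal, self-contained sketch of the pattern (a plain variable stands in for the axolotl config attribute):

    # Hypothetical stand-in for cfg.sequence_parallel_degree being unset
    sequence_parallel_degree = None

    # Unguarded: `None > 1` raises TypeError in Python 3
    try:
        if sequence_parallel_degree > 1:
            pass
    except TypeError as err:
        print(f"unguarded comparison fails: {err}")

    # Guarded: falsy values (None, 0) short-circuit before the comparison
    if sequence_parallel_degree and sequence_parallel_degree > 1:
        print("applying sequence parallel patches")
    else:
        print("skipping sequence parallel patches")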


@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
 def patch_prepare_data_loader():
     """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.

-    Raies:
+    Raises:
         RuntimeError: If source code to patch does not exist.
     """
     original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
         ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
     )

+    items_to_import = []
+    for item in dir(accelerate.data_loader):
+        if item in patched_source:
+            items_to_import.append(item)
+
     # Create a new function from the patched source
     namespace = {}
     exec(  # pylint: disable=exec-used  # nosec B102
-        patched_source, accelerate.data_loader.__dict__, namespace
+        f"from accelerate.data_loader import ({', '.join(items_to_import)})",
+        globals(),
+    )
+    exec(  # pylint: disable=exec-used  # nosec B102
+        patched_source, globals(), namespace
     )

-    patched_function = namespace["prepare_data_loader"]
-    accelerate.data_loader.prepare_data_loader = patched_function
+    patched_function = namespace["prepare_data_loader"]
+    original_fn.__code__ = patched_function.__code__

     LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")


-def patch_prepare_device_mesh(sequence_parallel_degree: int):
+def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
     """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
     that includes sequence parallelism with the specified degree.

     Args:
-        sequence_parallel_degree (int): The degree of sequence parallelism to use.
+        sequence_parallel_degree: The degree of sequence parallelism to use.
+        fsdp: Whether to use FSDP.
     """

     def _prepare_device_mesh(self):
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
             )

         device_ids = list(range(world_size))

-        # Note that we use "cp" instead of "sp" to match the PyTorch native "context
-        # parallelism" implementation naming
+        # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
+        # parallelism" implementation naming.
+        # NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
+        # only use "fsdp" and "cp" for the device mesh.
         return dist.DeviceMesh(
             "cuda",
             torch.tensor(device_ids).reshape(mesh_shape),
-            mesh_dim_names=("dp", "cp"),
+            mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
         )

     # Replace the original method with our new method
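
The `original_fn.__code__ = patched_function.__code__` assignment above is the "make patch stronger" part of this commit: rebinding `accelerate.data_loader.prepare_data_loader` only affects lookups that go through the module attribute, while swapping the function's code object also patches callers that grabbed a direct reference (e.g. via `from accelerate.data_loader import prepare_data_loader`) before the patch ran. A toy sketch of the technique on hypothetical functions, not the accelerate API:

    def greet():
        return "hello"

    alias = greet  # simulates a reference imported before patching

    def patched_greet():
        return "patched hello"

    # Rebinding `greet = patched_greet` would leave `alias` untouched;
    # swapping the code object retargets every live reference at once.
    greet.__code__ = patched_greet.__code__

    assert greet() == "patched hello"
    assert alias() == "patched hello"  # pre-existing reference is patched too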


@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
 from axolotl.monkeypatch.ring_attn import (
     get_ring_attn_group,
-    patch_prepare_data_loader,
-    patch_prepare_device_mesh,
     register_ring_attn,
     update_ring_attn_params,
 )
@@ -238,12 +236,6 @@ class SequenceParallelContextManager:
             ring_attn_func=self.ring_attn_func,
         )

-        # Patches for accelerate functionality
-        patch_prepare_data_loader()
-        patch_prepare_device_mesh(
-            sequence_parallel_degree=self.sequence_parallel_degree
-        )
-
     def _register_model_hooks(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
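
For orientation, the patched `_prepare_device_mesh` (now applied centrally by the `PatchManager` rather than by this context manager) lays all ranks out as a 2-D mesh whose first axis is data parallelism ("dp", or "fsdp" under FSDP) and whose second is context parallelism ("cp"). A sketch of the shape arithmetic, under the assumption that `mesh_shape` is `(world_size // sequence_parallel_degree, sequence_parallel_degree)`; the `DeviceMesh` construction is left as a comment since it requires an initialized process group:

    import torch

    world_size = 8                # assumed total number of ranks
    sequence_parallel_degree = 2  # SP/CP degree
    fsdp = False

    mesh_shape = (world_size // sequence_parallel_degree, sequence_parallel_degree)
    device_ids = list(range(world_size))

    mesh = torch.tensor(device_ids).reshape(mesh_shape)
    print(mesh)
    # tensor([[0, 1],
    #         [2, 3],
    #         [4, 5],
    #         [6, 7]])

    # With torch.distributed initialized, the patch wraps this as:
    # dist.DeviceMesh(
    #     "cuda",
    #     torch.tensor(device_ids).reshape(mesh_shape),
    #     mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
    # )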


@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
 # pylint: disable=duplicate-code
-def test_kernel_training_integration():
+def test_kernel_training_integration(temp_dir):
     """Test model loading with kernel patches enabled."""
     from axolotl.cli.utils import load_model_and_tokenizer

@@ -426,6 +426,14 @@ def test_kernel_training_integration():
         }
     )

+    # Write cfg to yaml file
+    path = Path(temp_dir) / "config.yaml"
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    # Load config
+    cfg = load_cfg(str(path))
+
     # Load model
     model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
     assert found_patched_attn

-def test_kernel_training_integration_dropout_non_zero():
+def test_kernel_training_integration_dropout_non_zero(temp_dir):
     """Test model loading with dropout non-zero should not patch."""
     from axolotl.cli.utils import load_model_and_tokenizer

@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
         }
     )

+    # Write cfg to yaml file
+    path = Path(temp_dir) / "config.yaml"
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    # Load config
+    cfg = load_cfg(str(path))
+
     # Get original attention class
     attention_cls = get_attention_cls_from_config(cfg)
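
The test fixes follow a write-then-reload pattern: dump the in-memory config to YAML in a temp directory, then run it back through `load_cfg` so the loaded object is exactly what a user's config file would produce (including validation). A self-contained sketch of the round trip using pytest's built-in `tmp_path` fixture and a plain dict in place of axolotl's config object:

    from pathlib import Path

    import yaml

    def test_config_round_trip(tmp_path: Path):
        # Hypothetical minimal config; stands in for cfg.to_dict()
        cfg = {"base_model": "some-org/some-model", "micro_batch_size": 1}

        # Write cfg to a YAML file in the temp directory
        path = tmp_path / "config.yaml"
        path.write_text(yaml.dump(cfg, Dumper=yaml.Dumper), encoding="utf-8")

        # Reload and verify the round trip preserved every value
        loaded = yaml.safe_load(path.read_text(encoding="utf-8"))
        assert loaded == cfg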