Ensure device mesh patching is applied (#2842)

* move patches; make patch stronger

* fix broken tests

* guard sequence_parallel_degree comparison against None

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Dan Saunders
2025-06-29 22:16:32 -04:00
committed by GitHub
parent cb811f8bf1
commit 35fdbce102
4 changed files with 52 additions and 19 deletions

View File

@@ -65,6 +65,7 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch() self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch() self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch() self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_model_load_patches(self, model: PreTrainedModel): def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance.""" """Apply patches that require the model instance."""
@@ -231,6 +232,17 @@ class PatchManager:
patch_gemma3_conditional_generation_forward() patch_gemma3_conditional_generation_forward()
def _apply_sequence_parallel_patches(self):
    """Apply the accelerate patches required for sequence parallelism.

    Nothing is patched unless `cfg.sequence_parallel_degree` is set and greater
    than 1. When it is, the data loader preparation patch and the device mesh
    patch are both applied.
    """
    sp_degree = self.cfg.sequence_parallel_degree
    if not sp_degree or not sp_degree > 1:
        # Unset (None / 0) or a degree of 1 means sequence parallelism is off.
        return

    # Imported here so the ring attention patch module is only loaded when
    # sequence parallelism is actually enabled.
    from axolotl.monkeypatch.ring_attn.patch import (
        patch_prepare_data_loader,
        patch_prepare_device_mesh,
    )

    patch_prepare_data_loader()
    patch_prepare_device_mesh(
        sequence_parallel_degree=sp_degree,
        fsdp=self.cfg.fsdp,
    )
def _patch_attention(self): def _patch_attention(self):
"""Apply attention-specific patches based on model type.""" """Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

View File

@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader(): def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raies: Raises:
RuntimeError: If source code to patch does not exist. RuntimeError: If source code to patch does not exist.
""" """
original_fn = accelerate.data_loader.prepare_data_loader original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
) )
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source # Create a new function from the patched source
namespace = {} namespace = {}
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
patched_source, accelerate.data_loader.__dict__, namespace f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
) )
patched_function = namespace["prepare_data_loader"]
accelerate.data_loader.prepare_data_loader = patched_function patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int): def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree. that includes sequence parallelism with the specified degree.
Args: Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use. sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
""" """
def _prepare_device_mesh(self): def _prepare_device_mesh(self):
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
) )
device_ids = list(range(world_size)) device_ids = list(range(world_size))
# Note that we use "cp" instead of "sp" to match the PyTorch native "context # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming # parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh( return dist.DeviceMesh(
"cuda", "cuda",
torch.tensor(device_ids).reshape(mesh_shape), torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"), mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
) )
# Replace the original method with our new method # Replace the original method with our new method

View File

@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import ( from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group, get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn, register_ring_attn,
update_ring_attn_params, update_ring_attn_params,
) )
@@ -238,12 +236,6 @@ class SequenceParallelContextManager:
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self): def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism # Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs): def sequence_parallel_pre_hook(_, args, kwargs):

View File

@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
def test_kernel_training_integration(): def test_kernel_training_integration(temp_dir):
"""Test model loading with kernel patches enabled.""" """Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
} }
) )
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model # Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg) model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero(): def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch.""" """Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
} }
) )
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class # Get original attention class
attention_cls = get_attention_cls_from_config(cfg) attention_cls = get_attention_cls_from_config(cfg)