Ensure device mesh patching is applied (#2842)
* move patches; make patch stronger * fix broken tests * guard sequence_parallel_degree comparison against none --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -65,6 +65,7 @@ class PatchManager:
|
||||
self._apply_mistral_cross_entropy_patch()
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_gemma3_conditional_generation_forward_patch()
|
||||
self._apply_sequence_parallel_patches()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
@@ -231,6 +232,17 @@ class PatchManager:
|
||||
|
||||
patch_gemma3_conditional_generation_forward()
|
||||
|
||||
def _apply_sequence_parallel_patches(self):
|
||||
"""Apply sequence parallelism patches."""
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.ring_attn.patch import (
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
)
|
||||
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||
|
||||
def _patch_attention(self):
|
||||
"""Apply attention-specific patches based on model type."""
|
||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||
|
||||
@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
def patch_prepare_data_loader():
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||
|
||||
Raies:
|
||||
Raises:
|
||||
RuntimeError: If source code to patch does not exist.
|
||||
"""
|
||||
original_fn = accelerate.data_loader.prepare_data_loader
|
||||
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
|
||||
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
||||
)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(accelerate.data_loader):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Create a new function from the patched source
|
||||
namespace = {}
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
patched_source, accelerate.data_loader.__dict__, namespace
|
||||
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
patched_source, globals(), namespace
|
||||
)
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
|
||||
accelerate.data_loader.prepare_data_loader = patched_function
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
original_fn.__code__ = patched_function.__code__
|
||||
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes sequence parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
||||
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||
fsdp: Whether to use FSDP.
|
||||
"""
|
||||
|
||||
def _prepare_device_mesh(self):
|
||||
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
|
||||
# parallelism" implementation naming
|
||||
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
|
||||
# parallelism" implementation naming.
|
||||
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
|
||||
# only use "fsdp" and "cp" for the device mesh.
|
||||
return dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(device_ids).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
|
||||
)
|
||||
|
||||
# Replace the original method with our new method
|
||||
|
||||
@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
@@ -238,12 +236,6 @@ class SequenceParallelContextManager:
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree
|
||||
)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply sequence parallelism
|
||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||
|
||||
@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration():
|
||||
def test_kernel_training_integration(temp_dir):
|
||||
"""Test model loading with kernel patches enabled."""
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
|
||||
}
|
||||
)
|
||||
|
||||
# Write cfg to yaml file
|
||||
path = Path(temp_dir) / "config.yaml"
|
||||
with open(path, "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Load model
|
||||
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
assert found_patched_attn
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero():
|
||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
"""Test model loading with dropout non-zero should not patch."""
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
|
||||
}
|
||||
)
|
||||
|
||||
# Write cfg to yaml file
|
||||
path = Path(temp_dir) / "config.yaml"
|
||||
with open(path, "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Get original attention class
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user