diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 3f8116b21..610e87c7b 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -65,6 +65,7 @@ class PatchManager:
         self._apply_mistral_cross_entropy_patch()
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
+        self._apply_sequence_parallel_patches()
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -231,6 +232,17 @@ class PatchManager:
 
         patch_gemma3_conditional_generation_forward()
 
+    def _apply_sequence_parallel_patches(self):
+        """Apply sequence parallelism patches."""
+        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
+            from axolotl.monkeypatch.ring_attn.patch import (
+                patch_prepare_data_loader,
+                patch_prepare_device_mesh,
+            )
+
+            patch_prepare_data_loader()
+            patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
+
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py
index d83476e5a..017b420d2 100644
--- a/src/axolotl/monkeypatch/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/ring_attn/patch.py
@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
 def patch_prepare_data_loader():
     """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
 
-    Raies:
+    Raises:
         RuntimeError: If source code to patch does not exist.
     """
     original_fn = accelerate.data_loader.prepare_data_loader
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
         ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
     )
 
+    items_to_import = []
+    for item in dir(accelerate.data_loader):
+        if item in patched_source:
+            items_to_import.append(item)
+
     # Create a new function from the patched source
     namespace = {}
     exec(  # pylint: disable=exec-used  # nosec B102
-        patched_source, accelerate.data_loader.__dict__, namespace
+        f"from accelerate.data_loader import ({', '.join(items_to_import)})",
+        globals(),
+    )
+    exec(  # pylint: disable=exec-used  # nosec B102
+        patched_source, globals(), namespace
     )
-    patched_function = namespace["prepare_data_loader"]
 
-    accelerate.data_loader.prepare_data_loader = patched_function
+    patched_function = namespace["prepare_data_loader"]
+    original_fn.__code__ = patched_function.__code__
+
     LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
 
 
-def patch_prepare_device_mesh(sequence_parallel_degree: int):
+def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
     """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
     that includes sequence parallelism with the specified degree.
 
     Args:
-        sequence_parallel_degree (int): The degree of sequence parallelism to use.
+        sequence_parallel_degree: The degree of sequence parallelism to use.
+        fsdp: Whether to use FSDP.
     """
 
     def _prepare_device_mesh(self):
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
         )
         device_ids = list(range(world_size))
 
-        # Note that we use "cp" instead of "sp" to match the PyTorch native "context
-        # parallelism" implementation naming
+        # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
+        # parallelism" implementation naming.
+        # NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
+        # only use "fsdp" and "cp" for the device mesh.
         return dist.DeviceMesh(
             "cuda",
             torch.tensor(device_ids).reshape(mesh_shape),
-            mesh_dim_names=("dp", "cp"),
+            mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
         )
 
     # Replace the original method with our new method
diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py
index f429cd2ae..1ac805a73 100644
--- a/src/axolotl/utils/ctx_managers/sequence_parallel.py
+++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py
@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
 
 from axolotl.monkeypatch.ring_attn import (
     get_ring_attn_group,
-    patch_prepare_data_loader,
-    patch_prepare_device_mesh,
     register_ring_attn,
     update_ring_attn_params,
 )
@@ -238,12 +236,6 @@ class SequenceParallelContextManager:
             ring_attn_func=self.ring_attn_func,
         )
 
-        # Patches for accelerate functionality
-        patch_prepare_data_loader()
-        patch_prepare_device_mesh(
-            sequence_parallel_degree=self.sequence_parallel_degree
-        )
-
     def _register_model_hooks(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index 56ce5a8b9..b4dc5de54 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
 
 
 # pylint: disable=duplicate-code
-def test_kernel_training_integration():
+def test_kernel_training_integration(temp_dir):
     """Test model loading with kernel patches enabled."""
     from axolotl.cli.utils import load_model_and_tokenizer
 
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
         }
     )
 
+    # Write cfg to yaml file
+    path = Path(temp_dir) / "config.yaml"
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    # Load config
+    cfg = load_cfg(str(path))
+
     # Load model
     model, _, _ = load_model_and_tokenizer(cfg=cfg)
 
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
     assert found_patched_attn
 
 
-def test_kernel_training_integration_dropout_non_zero():
+def test_kernel_training_integration_dropout_non_zero(temp_dir):
     """Test model loading with dropout non-zero should not patch."""
     from axolotl.cli.utils import load_model_and_tokenizer
 
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
         }
     )
 
+    # Write cfg to yaml file
+    path = Path(temp_dir) / "config.yaml"
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    # Load config
+    cfg = load_cfg(str(path))
+
     # Get original attention class
     attention_cls = get_attention_cls_from_config(cfg)