Compare commits
1 Commits
fix/issue-
...
fix/cp-was
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
255c5b90ca |
@@ -133,13 +133,6 @@ class PatchManager:
|
|||||||
patch_evaluation_loop()
|
patch_evaluation_loop()
|
||||||
patch_maybe_log_save_evaluate()
|
patch_maybe_log_save_evaluate()
|
||||||
|
|
||||||
if self.cfg.context_parallel_size > 1:
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
|
||||||
patch_prepare_context_parallel_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
|
||||||
|
|
||||||
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches right after model build, before post-load setup."""
|
"""Apply patches right after model build, before post-load setup."""
|
||||||
self._finalize_moe_expert_quantization(model)
|
self._finalize_moe_expert_quantization(model)
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ def patch_prepare_cp():
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
def patched_prepare_cp(self, *args):
|
def patched_prepare_cp(self, *args):
|
||||||
if self.parallelism_config.cp_backend == "deepspeed":
|
if self.parallelism_config.cp_backend == "deepspeed":
|
||||||
@@ -95,4 +96,11 @@ def patch_prepare_cp():
|
|||||||
self._cp_context = _noop_cp_context
|
self._cp_context = _noop_cp_context
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
def _noop_prepare_context_parallel_inputs(self, model, inputs):
|
||||||
|
return contextlib.nullcontext, inputs
|
||||||
|
|
||||||
|
# prevent double CP partition
|
||||||
Accelerator._prepare_cp = patched_prepare_cp
|
Accelerator._prepare_cp = patched_prepare_cp
|
||||||
|
|
||||||
|
# remove unneeded calculation upstream
|
||||||
|
Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
# Module-level logger for patch diagnostics.
LOG = get_logger(__name__)

# Exact guard line inside HF's Trainer._prepare_context_parallel_inputs; the
# replacement below is a plain substring swap, so this must match upstream
# byte-for-byte or the patch is skipped.
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
# Relaxed guard: accepts "flash_attention_2" in addition to "sdpa", and falls
# back to `model.model.config` when `model.config` has no _attn_implementation.
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
def patch_prepare_context_parallel_inputs() -> None:
    """Relax the SDPA-only guard when running context parallelism with FlashAttention.

    Rewrites the source of ``Trainer._prepare_context_parallel_inputs`` so its
    attention-implementation guard also accepts ``flash_attention_2``, then
    ``exec``s the rewritten function and installs it on ``Trainer``.

    Idempotent: a repeat call returns immediately. If the method's source is
    unavailable or the expected guard line is not found (e.g. upstream
    changed), a warning is logged and ``Trainer`` is left untouched.
    """
    if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
        LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
        return

    try:
        original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
    except OSError as exc:  # pragma: no cover - occurs when source is unavailable
        LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
        return

    if GUARD_PATTERN not in original_source:
        LOG.warning(
            "Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
            "skipping FlashAttention context parallelism patch"
        )
        return

    patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
    patched_source, _ = detab_code(patched_source)
    # Rename the exec'd definition so it cannot shadow the real method name.
    patched_source = patched_source.replace(
        "def _prepare_context_parallel_inputs(",
        "def axolotl_prepare_context_parallel_inputs(",
        1,
    )

    module_name = Trainer.__module__
    module = importlib.import_module(module_name)

    # Import symbols referenced in the method so exec can succeed. Substring
    # matching is deliberately loose: importing extra names is harmless.
    items_to_import = [item for item in dir(module) if item in patched_source]

    # Use a separate namespace to capture the exec'd function.
    namespace = {}
    if items_to_import:
        # Guard the import: "from m import ()" with an empty list would be a
        # SyntaxError inside exec.
        exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
    exec(patched_source, namespace)

    # Explicitly get the function from the namespace.
    axolotl_prepare_context_parallel_inputs = namespace[
        "axolotl_prepare_context_parallel_inputs"
    ]
    # Stash the original so tests/callers can restore the unpatched method.
    Trainer._original_prepare_context_parallel_inputs = (
        Trainer._prepare_context_parallel_inputs
    )
    Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
    Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
    Trainer._axolotl_prepare_context_parallel_inputs_patched = True
    LOG.debug(
        "Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
    )
@@ -1,66 +0,0 @@
|
|||||||
"""Tests for the HF Trainer context parallel patch."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
|
||||||
GUARD_PATTERN,
|
|
||||||
PATCHED_GUARD,
|
|
||||||
patch_prepare_context_parallel_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
def restore_trainer_prepare_method():
    """Ensure Trainer._prepare_context_parallel_inputs is restored after a test.

    Snapshots the pre-test state (preferring the stashed original method when a
    patch is already installed), yields to the test, then reinstalls the method
    and removes the bookkeeping attributes the patch sets on Trainer.
    """
    # Prefer the saved original if the class entered the test already patched.
    original_method = getattr(
        Trainer,
        "_original_prepare_context_parallel_inputs",
        Trainer._prepare_context_parallel_inputs,
    )
    # Record whether the "patched" flag existed BEFORE the test ran.
    patched_attr_present = hasattr(
        Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
    )

    yield

    Trainer._prepare_context_parallel_inputs = original_method
    # NOTE(review): the flag is deleted only when it was present before the
    # test; a flag newly set BY the test is left in place, and a pre-existing
    # flag removed by the test would make this delattr raise — confirm intent.
    if patched_attr_present:
        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
        delattr(Trainer, "_original_prepare_context_parallel_inputs")
    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_attention_guard(restore_trainer_prepare_method):
    """Patch should swap the guard to allow sdpa or flash attention."""
    # Reset Trainer to its pristine, unpatched method before exercising the patch.
    stashed = getattr(Trainer, "_original_prepare_context_parallel_inputs", None)
    if stashed is not None:
        Trainer._prepare_context_parallel_inputs = stashed
        delattr(Trainer, "_original_prepare_context_parallel_inputs")
    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")

    patch_prepare_context_parallel_inputs()

    # The patched method must be installed and the idempotency flag set.
    assert Trainer._prepare_context_parallel_inputs is not None
    assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)

    # The rewritten source keeps only the relaxed guard.
    rewritten_source = Trainer._axolotl_prepare_context_parallel_inputs_source
    assert GUARD_PATTERN not in rewritten_source
    assert PATCHED_GUARD in rewritten_source
def test_patch_is_idempotent(restore_trainer_prepare_method):
    """Calling the patch twice should leave the same patched function in place."""
    patch_prepare_context_parallel_inputs()
    method_after_first_call = Trainer._prepare_context_parallel_inputs

    patch_prepare_context_parallel_inputs()
    method_after_second_call = Trainer._prepare_context_parallel_inputs

    # The second invocation must be a no-op: identical function object.
    assert method_after_first_call is method_after_second_call
Reference in New Issue
Block a user