use nanmean for loss aggregation (CP fix) (#3033)

* use nanmean for loss aggregation (CP fix)

* use regular asserts

* small changes to make tests isolated

* combining evaluation_loop patches

* fix

* delete unused

* fix check
This commit is contained in:
Dan Saunders
2025-08-08 08:15:17 -04:00
committed by GitHub
parent 2974670bf8
commit 0ae06d756d
6 changed files with 207 additions and 83 deletions

View File

@@ -1,5 +1,5 @@
"""Model loader class implementation for loading, configuring, and patching various
models.
"""
Model loader class implementation for loading, configuring, and patching various models.
"""
import gc

View File

@@ -76,8 +76,20 @@ class PatchManager:
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
)
patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)
patch_prepare_from_posids()
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""

View File

@@ -1,78 +0,0 @@
"""
fix for FSDP2 evals when using torch.compile
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()
"""
PATCHED_TRAINER_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
def get_evaluation_loop_code() -> str:
    """Return the source code of ``Trainer.evaluation_loop``."""
    return inspect.getsource(Trainer.evaluation_loop)
def check_evaluation_loop_is_patchable() -> bool:
    """Report whether the upstream eval loop still contains the code we patch."""
    detabbed_source, _ = detab_code(get_evaluation_loop_code())
    return ORIGINAL_TRAINER_CODE in detabbed_source
def patch_evaluation_loop_for_fsdp2():
    """Monkeypatch ``Trainer.evaluation_loop`` for FSDP2 with ``torch.compile``.

    A compiled/wrapped model may not expose a callable ``eval`` attribute, so
    the patched source guards the ``model.eval()`` call and falls back to
    ``self.model.eval()``. No-ops when the upstream source is unavailable or
    the target snippet has changed upstream.
    """
    # Source may be unavailable (e.g. frozen/zipped install, or the method was
    # already replaced by an exec'd function); skip silently in that case.
    try:
        evaluation_loop = get_evaluation_loop_code()
    except OSError:
        return

    # Keep the original source around for debugging/idempotence checks.
    Trainer._original_evaluation_loop = (  # pylint: disable=protected-access
        evaluation_loop
    )
    evaluation_loop, _ = detab_code(evaluation_loop)

    # Upstream changed and the pattern is gone; skip rather than corrupt.
    if ORIGINAL_TRAINER_CODE not in evaluation_loop:
        return

    evaluation_loop = evaluation_loop.replace(
        ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
    )
    # Rename so the exec'd definition does not collide with the real method.
    evaluation_loop = evaluation_loop.replace(
        "def evaluation_loop(",
        "def _fixed_evaluation_loop(",
        1,
    )

    # Re-import every transformers.trainer name the source references so the
    # exec'd function can resolve its globals.
    import transformers.trainer

    items_to_import = []
    for item in dir(transformers.trainer):
        if item in evaluation_loop:
            items_to_import.append(item)

    exec(  # pylint: disable=exec-used # nosec B102
        "from transformers.trainer import ("
        + ", ".join(x for x in items_to_import)
        + ")",
        globals(),
    )
    exec(evaluation_loop, globals())  # pylint: disable=exec-used # nosec B102

    # Fixed log message: the old text referenced _inner_training_loop / the
    # FSDP optimizer-save patch, which is not what this function patches.
    LOG.info("patching evaluation_loop with FSDP2 + torch.compile eval guard")
    Trainer.evaluation_loop = (  # pylint: disable=protected-access
        _fixed_evaluation_loop  # pylint: disable=undefined-variable # noqa: F821
    )

View File

@@ -0,0 +1,165 @@
"""
Module for patching transformers Trainer loss calculation to use nanmean.
This is needed for context parallelism since chunks of the input sequences may be fully
masked and return NaNs in the loss calculation.
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
the other evaluation_loop patch because we can't patch the same code twice without
raising an OSError.
"""
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
}
PATCHED_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_FSDP2_CODE = """
model.eval()
"""
PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
def check_evaluation_loop_is_patchable() -> bool:
    """Report whether every nanmean target line still exists upstream."""
    source = inspect.getsource(Trainer.evaluation_loop)
    for pattern in ORIGINAL_EVAL_CODE.values():
        if pattern not in source:
            return False
    return True
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
    """Report whether the FSDP2 ``model.eval()`` guard target still exists upstream."""
    source, _ = detab_code(inspect.getsource(Trainer.evaluation_loop))
    return ORIGINAL_FSDP2_CODE in source
# pylint: disable=protected-access
def patch_evaluation_loop(patch_fsdp2: bool):
    """Patch ``Trainer.evaluation_loop`` to aggregate eval loss with ``np.nanmean``.

    Context parallelism can leave chunks of a sequence fully masked, yielding
    NaN per-chunk losses that a plain mean would propagate. The FSDP2 +
    torch.compile ``model.eval()`` guard is optionally bundled in here because
    the same method cannot be source-patched twice (``inspect.getsource``
    raises OSError on exec'd code).

    Args:
        patch_fsdp2: Also guard the ``model.eval()`` call for FSDP2 +
            torch.compile runs.
    """
    # Check if already patched
    if hasattr(Trainer, "_original_evaluation_loop"):
        LOG.info("Trainer.evaluation_loop already patched")
        return

    # Source may be unavailable (e.g. frozen/zipped install); skip silently.
    try:
        evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
    except OSError:
        return

    # Stash the original source so the idempotence guard above actually fires.
    # BUG FIX: this was previously assigned to `Trainer.evaluation`, which
    # leaked a stray attribute and never satisfied the hasattr check above,
    # so repeated calls would re-run the patch.
    Trainer._original_evaluation_loop = evaluation_loop_source
    evaluation_loop_source, _ = detab_code(evaluation_loop_source)

    # Apply the nanmean patches (both loss-aggregation variants)
    evaluation_loop_source = evaluation_loop_source.replace(
        ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
    )
    evaluation_loop_source = evaluation_loop_source.replace(
        ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
    )

    # Apply FSDP2 eval guard patch if needed
    if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
        evaluation_loop_source = evaluation_loop_source.replace(
            ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
        )
        LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")

    # Rename the function to avoid conflicts with the real method
    evaluation_loop_source = evaluation_loop_source.replace(
        "def evaluation_loop(",
        "def axolotl_evaluation_loop(",
        1,
    )

    # Re-import every module-level name the source references so the exec'd
    # function can resolve its globals.
    module_name = Trainer.__module__
    module = importlib.import_module(module_name)
    items_to_import = [item for item in dir(module) if item in evaluation_loop_source]

    exec(  # pylint: disable=exec-used # nosec B102
        f"from {module_name} import ({', '.join(items_to_import)})",
        globals(),
    )
    exec(evaluation_loop_source, globals())  # pylint: disable=exec-used # nosec B102

    LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
    Trainer.evaluation_loop = (
        axolotl_evaluation_loop  # pylint: disable=undefined-variable # noqa: F821
    )
def check_maybe_log_save_evaluate_is_patchable() -> bool:
    """Report whether the train-loss aggregation line still exists upstream."""
    source = inspect.getsource(Trainer._maybe_log_save_evaluate)
    return ORIGINAL_MAYBE_CODE in source
# pylint: disable=protected-access
def patch_maybe_log_save_evaluate():
    """Patch ``Trainer._maybe_log_save_evaluate`` to use nanmean on the train loss."""
    # Already patched? The original source is stashed on the class below.
    if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
        LOG.info("Trainer._maybe_log_save_evaluate already patched")
        return

    # Source may be unavailable (e.g. frozen/zipped install); skip silently.
    try:
        source = inspect.getsource(Trainer._maybe_log_save_evaluate)
    except OSError:
        return

    # Stash the original source so the idempotence guard above fires.
    Trainer._original_maybe_log_save_evaluate = source
    source, _ = detab_code(source)

    # mean -> nanmean on the gathered training loss
    source = source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)

    # Rename so the exec'd definition does not collide with the real method.
    source = source.replace(
        "def _maybe_log_save_evaluate(",
        "def axolotl_maybe_log_save_evaluate(",
        1,
    )

    # Re-import every module-level name the source references so the exec'd
    # function can resolve its globals.
    module_name = Trainer.__module__
    module = importlib.import_module(module_name)
    items_to_import = [item for item in dir(module) if item in source]

    exec(  # pylint: disable=exec-used # nosec B102
        f"from {module_name} import ({', '.join(items_to_import)})",
        globals(),
    )
    exec(source, globals())  # pylint: disable=exec-used # nosec B102

    LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
    Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate  # pylint: disable=undefined-variable # noqa: F821

View File

@@ -15,7 +15,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
@@ -687,8 +686,6 @@ def setup_trainer(
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
patch_evaluation_loop_for_fsdp2()
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref

View File

@@ -0,0 +1,28 @@
"""Unit tests for trainer loss calc monkeypatch."""
import unittest
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)
class TestTrainerLossCalc(unittest.TestCase):
    """Guards against upstream transformers changes breaking our source patches."""

    def test_trainer_loss_calc_is_patchable(self):
        """Fail when any of the code we source-patch has changed upstream."""
        patchability_checks = (
            check_evaluation_loop_is_patchable,
            check_evaluation_loop_is_fsdp2_patchable,
            check_maybe_log_save_evaluate_is_patchable,
        )
        for check in patchability_checks:
            assert check()
if __name__ == "__main__":
unittest.main()