From 0ae06d756dcb24d8493bd59290fac3ec06ea27a3 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 8 Aug 2025 08:15:17 -0400
Subject: [PATCH] use nanmean for loss aggregation (CP fix) (#3033)

* use nanmean for loss aggregation (CP fix)

* use regular asserts

* small changes to isolate tests

* combining evaluation_loop patches

* fix

* delete unused

* fix check
---
 src/axolotl/loaders/model.py                  |   4 +-
 src/axolotl/loaders/patch_manager.py          |  12 ++
 src/axolotl/monkeypatch/trainer_eval_guard.py |  78 ---------
 .../transformers/trainer_loss_calc.py         | 165 ++++++++++++++++++
 src/axolotl/utils/trainer.py                  |   3 -
 tests/monkeypatch/test_trainer_loss_calc.py   |  28 +++
 6 files changed, 207 insertions(+), 83 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/trainer_eval_guard.py
 create mode 100644 src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
 create mode 100644 tests/monkeypatch/test_trainer_loss_calc.py

diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 95a56b326..6bf1f149b 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -1,5 +1,5 @@
-"""Model loader class implementation for loading, configuring, and patching various
-models.
+"""
+Model loader class implementation for loading, configuring, and patching various models.
 """
 
 import gc
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 047eb20fd..795fc3e37 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -76,8 +76,20 @@ class PatchManager:
         from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
             patch_prepare_from_posids,
         )
+        from axolotl.monkeypatch.transformers.trainer_loss_calc import (
+            patch_evaluation_loop,
+            patch_maybe_log_save_evaluate,
+        )
+
+        patch_fsdp2 = (
+            self.cfg.torch_compile
+            and self.cfg.fsdp_config
+            and self.cfg.fsdp_version == 2
+        )
 
         patch_prepare_from_posids()
+        patch_evaluation_loop(patch_fsdp2)
+        patch_maybe_log_save_evaluate()
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
diff --git a/src/axolotl/monkeypatch/trainer_eval_guard.py b/src/axolotl/monkeypatch/trainer_eval_guard.py
deleted file mode 100644
index 8488a16df..000000000
--- a/src/axolotl/monkeypatch/trainer_eval_guard.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""
-fix for FSDP2 evals when using torch.compile
-"""
-
-import inspect
-
-from transformers import Trainer
-
-from axolotl.monkeypatch.utils import detab_code
-from axolotl.utils.logging import get_logger
-
-LOG = get_logger(__name__)
-
-ORIGINAL_TRAINER_CODE = """
-    model.eval()
-"""
-
-PATCHED_TRAINER_CODE = """
-    if hasattr(model, "eval") and callable(model.eval):
-        self.model.eval()
-"""
-
-
-def get_evaluation_loop_code() -> str:
-    training_loop = inspect.getsource(Trainer.evaluation_loop)
-    return training_loop
-
-
-def check_evaluation_loop_is_patchable() -> bool:
-    eval_loop = get_evaluation_loop_code()
-    eval_loop, _ = detab_code(eval_loop)
-    return ORIGINAL_TRAINER_CODE in eval_loop
-
-
-def patch_evaluation_loop_for_fsdp2():
-    """
-    monkeypatch for fixing the eval loop for fsdp2 with torch.compile
-    """
-
-    try:
-        evaluation_loop = get_evaluation_loop_code()
-    except OSError:
-        return
-    Trainer._original_evaluation_loop = (  # pylint: disable=protected-access
-        evaluation_loop
-    )
-    evaluation_loop, _ = detab_code(evaluation_loop)
-    if ORIGINAL_TRAINER_CODE not in evaluation_loop:
-        return
-
-    evaluation_loop = evaluation_loop.replace(
-        ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
-    )
-    evaluation_loop = evaluation_loop.replace(
-        "def evaluation_loop(",
-        "def _fixed_evaluation_loop(",
-        1,
-    )
-
-    # load imports necessary
-    import transformers.trainer
-
-    items_to_import = []
-    for item in dir(transformers.trainer):
-        if item in evaluation_loop:
-            items_to_import.append(item)
-
-    exec(  # pylint: disable=exec-used  # nosec B102
-        "from transformers.trainer import ("
-        + ", ".join(x for x in items_to_import)
-        + ")",
-        globals(),
-    )
-    exec(evaluation_loop, globals())  # pylint: disable=exec-used  # nosec B102
-    LOG.info("patching _inner_training_loop for fsdp optimizer save")
-    Trainer.evaluation_loop = (  # pylint: disable=protected-access
-        _fixed_evaluation_loop  # pylint: disable=undefined-variable  # noqa: F821
-    )
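
Note: both the guard deleted above and the trainer_loss_calc module added below
rely on the same source-rewrite monkeypatch pattern: fetch the method's source
with inspect, rewrite a known snippet, exec the result under a new name, and
rebind it on the class. A minimal, self-contained sketch of that pattern (the
Toy class is an illustrative stand-in for Trainer, and textwrap.dedent plays
the role of axolotl's detab_code; none of this is code from the patch):

import inspect
import textwrap


class Toy:
    def loop(self):
        return "original"


source = inspect.getsource(Toy.loop)
# De-indent the method source so it can be exec'd at module level.
source = textwrap.dedent(source)
# Rewrite a known snippet, then rename the function so the exec'd version
# cannot collide with the original definition.
source = source.replace('"original"', '"patched"')
source = source.replace("def loop(", "def patched_loop(", 1)
namespace = {}
exec(source, namespace)  # nosec B102 - same technique the patch relies on
Toy.loop = namespace["patched_loop"]
assert Toy().loop() == "patched"
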
diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
new file mode 100644
index 000000000..75f4158b3
--- /dev/null
+++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
@@ -0,0 +1,165 @@
+"""
+Module for patching transformers Trainer loss calculation to use nanmean.
+
+This is needed for context parallelism, since a rank's chunk of the input sequence may
+be fully masked and return NaN from the loss calculation.
+
+Also includes a patch for FSDP2 + torch.compile. It must be bundled with the nanmean
+evaluation_loop patch because inspect.getsource raises an OSError for a method that
+was already redefined via exec, so the same method cannot be source-patched twice.
+"""
+
+import importlib
+import inspect
+
+from transformers import Trainer
+
+from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+ORIGINAL_EVAL_CODE = {
+    "list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
+    "array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
+}
+PATCHED_EVAL_CODE = {
+    "list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
+    "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
+}
+
+ORIGINAL_FSDP2_CODE = """
+    model.eval()
+"""
+
+PATCHED_FSDP2_CODE = """
+    if hasattr(model, "eval") and callable(model.eval):
+        self.model.eval()
+"""
+
+ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
+PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
+
+
+def check_evaluation_loop_is_patchable() -> bool:
+    evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
+    return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
+
+
+def check_evaluation_loop_is_fsdp2_patchable() -> bool:
+    evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
+    evaluation_loop_source, _ = detab_code(evaluation_loop_source)
+    return ORIGINAL_FSDP2_CODE in evaluation_loop_source
+
+
+# pylint: disable=protected-access
+def patch_evaluation_loop(patch_fsdp2: bool):
+    """Patch the evaluation_loop method."""
+    # Check if already patched
+    if hasattr(Trainer, "_original_evaluation_loop"):
+        LOG.info("Trainer.evaluation_loop already patched")
+        return
+
+    # Get the method source; bail out if it is unavailable
+    try:
+        evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
+    except OSError:
+        return
+    Trainer._original_evaluation_loop = evaluation_loop_source
+    evaluation_loop_source, _ = detab_code(evaluation_loop_source)
+
+    # Apply the nanmean patches
+    evaluation_loop_source = evaluation_loop_source.replace(
+        ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
+    )
+    evaluation_loop_source = evaluation_loop_source.replace(
+        ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
+    )
+
+    # Apply FSDP2 eval guard patch if needed
+    if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
+        evaluation_loop_source = evaluation_loop_source.replace(
+            ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
+        )
+        LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
+
+    # Rename the function to avoid conflicts
+    evaluation_loop_source = evaluation_loop_source.replace(
+        "def evaluation_loop(",
+        "def axolotl_evaluation_loop(",
+        1,
+    )
+
+    # Get the module for necessary imports
+    module_name = Trainer.__module__
+    module = importlib.import_module(module_name)
+
+    # Import necessary items from the module
+    items_to_import = []
+    for item in dir(module):
+        if item in evaluation_loop_source:
+            items_to_import.append(item)
+
+    # Execute the imports and patched method
+    exec(  # pylint: disable=exec-used  # nosec B102
+        f"from {module_name} import ({', '.join(items_to_import)})",
+        globals(),
+    )
+    exec(evaluation_loop_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+    LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
+    Trainer.evaluation_loop = (
+        axolotl_evaluation_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
+
+
+def check_maybe_log_save_evaluate_is_patchable() -> bool:
+    maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
+    return ORIGINAL_MAYBE_CODE in maybe_log_source
+
+
+# pylint: disable=protected-access
+def patch_maybe_log_save_evaluate():
+    """Patch the _maybe_log_save_evaluate method."""
+    # Check if already patched
+    if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
+        LOG.info("Trainer._maybe_log_save_evaluate already patched")
+        return
+
+    # Get the method source; bail out if it is unavailable
+    try:
+        maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
+    except OSError:
+        return
+    Trainer._original_maybe_log_save_evaluate = maybe_log_source
+    maybe_log_source, _ = detab_code(maybe_log_source)
+
+    # Apply the patch
+    maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
+
+    # Rename the function to avoid conflicts
+    maybe_log_source = maybe_log_source.replace(
+        "def _maybe_log_save_evaluate(",
+        "def axolotl_maybe_log_save_evaluate(",
+        1,
+    )
+
+    # Get the module for necessary imports
+    module_name = Trainer.__module__
+    module = importlib.import_module(module_name)
+
+    # Import necessary items from the module
+    items_to_import = []
+    for item in dir(module):
+        if item in maybe_log_source:
+            items_to_import.append(item)
+
+    # Execute the imports and patched method
+    exec(  # pylint: disable=exec-used  # nosec B102
+        f"from {module_name} import ({', '.join(items_to_import)})",
+        globals(),
+    )
+    exec(maybe_log_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+    LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
+    Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate  # pylint: disable=undefined-variable  # noqa: F821
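
To make the motivation concrete: with context parallelism the sequence is
sharded across ranks, and a rank whose shard contains only masked labels
produces a NaN loss; a plain mean then propagates that NaN into the logged
metric. A small demonstration with made-up loss values (not part of the patch):

import numpy as np
import torch

# Per-chunk eval losses; one chunk was fully masked and came back NaN.
eval_losses = np.array([0.91, 1.02, float("nan"), 0.87])
print(np.mean(eval_losses))     # nan -- one masked chunk poisons the metric
print(np.nanmean(eval_losses))  # ~0.933 -- the masked chunk is ignored

# The _maybe_log_save_evaluate patch does the same on the torch side.
tr_loss = torch.tensor([0.91, 1.02, float("nan"), 0.87])
print(tr_loss.nanmean().item())  # ~0.933 instead of nan
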
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 26634cbbe..e424cb55a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -15,7 +15,6 @@
 from datasets import IterableDataset, disable_caching, enable_caching
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
-from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
 from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.logging import get_logger
@@ -687,8 +686,6 @@
     """
     from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
 
-    if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
-        patch_evaluation_loop_for_fsdp2()
     if cfg.rl:
         trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
         trainer_builder.model_ref = model_ref
diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/monkeypatch/test_trainer_loss_calc.py
new file mode 100644
index 000000000..de3e92621
--- /dev/null
+++ b/tests/monkeypatch/test_trainer_loss_calc.py
@@ -0,0 +1,28 @@
+"""Unit tests for trainer loss calc monkeypatch."""
+
+import unittest
+
+from axolotl.monkeypatch.transformers.trainer_loss_calc import (
+    check_evaluation_loop_is_fsdp2_patchable,
+    check_evaluation_loop_is_patchable,
+    check_maybe_log_save_evaluate_is_patchable,
+)
+
+
+class TestTrainerLossCalc(unittest.TestCase):
+    """
+    Unit test class for trainer loss calc monkeypatch
+    """
+
+    def test_trainer_loss_calc_is_patchable(self):
+        """
+        Test that the upstream transformers code is still patchable. This will fail if
+        the patched code changes upstream.
+        """
+        assert check_evaluation_loop_is_patchable()
+        assert check_evaluation_loop_is_fsdp2_patchable()
+        assert check_maybe_log_save_evaluate_is_patchable()
+
+
+if __name__ == "__main__":
+    unittest.main()
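
For reference, a usage sketch assembled only from the functions this patch adds;
the check_* helpers guard against upstream drift in transformers, and the boolean
mirrors the patch_fsdp2 gating in PatchManager (illustrative call site, not part
of the patch):

from axolotl.monkeypatch.transformers.trainer_loss_calc import (
    check_evaluation_loop_is_patchable,
    check_maybe_log_save_evaluate_is_patchable,
    patch_evaluation_loop,
    patch_maybe_log_save_evaluate,
)

# If the exact upstream strings are gone, the checks return False and the
# string replacements would be no-ops, so skip patching.
if check_evaluation_loop_is_patchable():
    patch_evaluation_loop(patch_fsdp2=False)  # True only for FSDP2 + torch.compile
if check_maybe_log_save_evaluate_is_patchable():
    patch_maybe_log_save_evaluate()
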