use nanmean for loss aggregation (CP fix) (#3033)
* use nanmena for loss aggregation (CP fix) * use regular asserts * small changes to make tests isolate * combining evaluation_loop patches * fix * delete unused * fix check
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""Model loader class implementation for loading, configuring, and patching various
|
"""
|
||||||
models.
|
Model loader class implementation for loading, configuring, and patching various models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
|||||||
@@ -76,8 +76,20 @@ class PatchManager:
|
|||||||
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||||
patch_prepare_from_posids,
|
patch_prepare_from_posids,
|
||||||
)
|
)
|
||||||
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
|
patch_evaluation_loop,
|
||||||
|
patch_maybe_log_save_evaluate,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_fsdp2 = (
|
||||||
|
self.cfg.torch_compile
|
||||||
|
and self.cfg.fsdp_config
|
||||||
|
and self.cfg.fsdp_version == 2
|
||||||
|
)
|
||||||
|
|
||||||
patch_prepare_from_posids()
|
patch_prepare_from_posids()
|
||||||
|
patch_evaluation_loop(patch_fsdp2)
|
||||||
|
patch_maybe_log_save_evaluate()
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
|
|||||||
@@ -1,78 +0,0 @@
|
|||||||
"""
|
|
||||||
fix for FSDP2 evals when using torch.compile
|
|
||||||
"""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
ORIGINAL_TRAINER_CODE = """
|
|
||||||
model.eval()
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_TRAINER_CODE = """
|
|
||||||
if hasattr(model, "eval") and callable(model.eval):
|
|
||||||
self.model.eval()
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_evaluation_loop_code() -> str:
|
|
||||||
training_loop = inspect.getsource(Trainer.evaluation_loop)
|
|
||||||
return training_loop
|
|
||||||
|
|
||||||
|
|
||||||
def check_evaluation_loop_is_patchable() -> bool:
|
|
||||||
eval_loop = get_evaluation_loop_code()
|
|
||||||
eval_loop, _ = detab_code(eval_loop)
|
|
||||||
return ORIGINAL_TRAINER_CODE in eval_loop
|
|
||||||
|
|
||||||
|
|
||||||
def patch_evaluation_loop_for_fsdp2():
|
|
||||||
"""
|
|
||||||
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
evaluation_loop = get_evaluation_loop_code()
|
|
||||||
except OSError:
|
|
||||||
return
|
|
||||||
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
|
|
||||||
evaluation_loop
|
|
||||||
)
|
|
||||||
evaluation_loop, _ = detab_code(evaluation_loop)
|
|
||||||
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
|
|
||||||
return
|
|
||||||
|
|
||||||
evaluation_loop = evaluation_loop.replace(
|
|
||||||
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
|
|
||||||
)
|
|
||||||
evaluation_loop = evaluation_loop.replace(
|
|
||||||
"def evaluation_loop(",
|
|
||||||
"def _fixed_evaluation_loop(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.trainer
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.trainer):
|
|
||||||
if item in evaluation_loop:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.trainer import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
|
||||||
Trainer.evaluation_loop = ( # pylint: disable=protected-access
|
|
||||||
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
)
|
|
||||||
165
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Normal file
165
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
Module for patching transformers Trainer loss calculation to use nanmean.
|
||||||
|
|
||||||
|
This is needed for context parallelism since chunks of the input sequences may be fully
|
||||||
|
masked and return NaNs in the loss calculation.
|
||||||
|
|
||||||
|
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
|
||||||
|
the other evaluation_loop patch because we can't patch the same code twice without
|
||||||
|
raising an OSError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
ORIGINAL_EVAL_CODE = {
|
||||||
|
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
|
||||||
|
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
|
||||||
|
}
|
||||||
|
PATCHED_EVAL_CODE = {
|
||||||
|
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
|
||||||
|
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
||||||
|
}
|
||||||
|
|
||||||
|
ORIGINAL_FSDP2_CODE = """
|
||||||
|
model.eval()
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_FSDP2_CODE = """
|
||||||
|
if hasattr(model, "eval") and callable(model.eval):
|
||||||
|
self.model.eval()
|
||||||
|
"""
|
||||||
|
|
||||||
|
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
||||||
|
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
||||||
|
|
||||||
|
|
||||||
|
def check_evaluation_loop_is_patchable() -> bool:
|
||||||
|
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
||||||
|
|
||||||
|
|
||||||
|
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
|
||||||
|
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||||
|
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def patch_evaluation_loop(patch_fsdp2: bool):
|
||||||
|
"""Patch the evaluation_loop method."""
|
||||||
|
# Check if already patched
|
||||||
|
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||||
|
LOG.info("Trainer.evaluation_loop already patched")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the patterns exist
|
||||||
|
try:
|
||||||
|
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer.evaluation = evaluation_loop_source
|
||||||
|
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||||
|
|
||||||
|
# Apply the nanmean patches
|
||||||
|
evaluation_loop_source = evaluation_loop_source.replace(
|
||||||
|
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
|
||||||
|
)
|
||||||
|
evaluation_loop_source = evaluation_loop_source.replace(
|
||||||
|
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply FSDP2 eval guard patch if needed
|
||||||
|
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
|
||||||
|
evaluation_loop_source = evaluation_loop_source.replace(
|
||||||
|
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
|
||||||
|
)
|
||||||
|
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
|
||||||
|
|
||||||
|
# Rename the function to avoid conflicts
|
||||||
|
evaluation_loop_source = evaluation_loop_source.replace(
|
||||||
|
"def evaluation_loop(",
|
||||||
|
"def axolotl_evaluation_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the module for necessary imports
|
||||||
|
module_name = Trainer.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# Import necessary items from the module
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(module):
|
||||||
|
if item in evaluation_loop_source:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
# Execute the imports and patched method
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
|
||||||
|
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
||||||
|
Trainer.evaluation_loop = (
|
||||||
|
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_maybe_log_save_evaluate_is_patchable() -> bool:
|
||||||
|
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||||
|
return ORIGINAL_MAYBE_CODE in maybe_log_source
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def patch_maybe_log_save_evaluate():
|
||||||
|
"""Patch the _maybe_log_save_evaluate method."""
|
||||||
|
# Check if already patched
|
||||||
|
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
|
||||||
|
LOG.info("Trainer._maybe_log_save_evaluate already patched")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the patterns exist
|
||||||
|
try:
|
||||||
|
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_maybe_log_save_evaluate = maybe_log_source
|
||||||
|
maybe_log_source, _ = detab_code(maybe_log_source)
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
|
||||||
|
|
||||||
|
# Rename the function to avoid conflicts
|
||||||
|
maybe_log_source = maybe_log_source.replace(
|
||||||
|
"def _maybe_log_save_evaluate(",
|
||||||
|
"def axolotl_maybe_log_save_evaluate(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the module for necessary imports
|
||||||
|
module_name = Trainer.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# Import necessary items from the module
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(module):
|
||||||
|
if item in maybe_log_source:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
# Execute the imports and patched method
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
|
||||||
|
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
||||||
|
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821
|
||||||
@@ -15,7 +15,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
|||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
|
||||||
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -687,8 +686,6 @@ def setup_trainer(
|
|||||||
"""
|
"""
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
|
|
||||||
patch_evaluation_loop_for_fsdp2()
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||||
trainer_builder.model_ref = model_ref
|
trainer_builder.model_ref = model_ref
|
||||||
|
|||||||
28
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
28
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""Unit tests for trainer loss calc monkeypatch."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
|
check_evaluation_loop_is_fsdp2_patchable,
|
||||||
|
check_evaluation_loop_is_patchable,
|
||||||
|
check_maybe_log_save_evaluate_is_patchable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerLossCalc(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Unit test class for trainer loss calc monkeypatch
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_trainer_loss_calc_is_patchable(self):
|
||||||
|
"""
|
||||||
|
Test that the upstream transformers code is still patchable. This will fail if
|
||||||
|
the patched code changes upstream.
|
||||||
|
"""
|
||||||
|
assert check_evaluation_loop_is_patchable()
|
||||||
|
assert check_evaluation_loop_is_fsdp2_patchable()
|
||||||
|
assert check_maybe_log_save_evaluate_is_patchable()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user