use nanmean for loss aggregation (CP fix) (#3033)

* use nanmean for loss aggregation (CP fix)

* use regular asserts

* small changes to make tests isolated

* combining evaluation_loop patches

* fix

* delete unused

* fix check
This commit is contained in:
Dan Saunders
2025-08-08 08:15:17 -04:00
committed by GitHub
parent 2974670bf8
commit 0ae06d756d
6 changed files with 207 additions and 83 deletions

View File

@@ -0,0 +1,28 @@
"""Unit tests for trainer loss calc monkeypatch."""
import unittest
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)
class TestTrainerLossCalc(unittest.TestCase):
    """
    Patchability checks for the trainer loss calc monkeypatch
    """

    def test_trainer_loss_calc_is_patchable(self):
        """
        Verify that the upstream transformers code can still be patched. A failure
        here means the code targeted by the patch has changed upstream.
        """
        patchability_checks = (
            check_evaluation_loop_is_patchable,
            check_evaluation_loop_is_fsdp2_patchable,
            check_maybe_log_save_evaluate_is_patchable,
        )
        # Checks run in their original order; the first falsy result fails the test.
        for is_patchable in patchability_checks:
            assert is_patchable()
# Allow running this test module directly as a script via the unittest runner.
if __name__ == "__main__":
    unittest.main()