* use nanmean for loss aggregation (CP fix) * use regular asserts * small changes to isolate tests * combining evaluation_loop patches * fix * delete unused * fix check
29 lines
814 B
Python
29 lines
814 B
Python
"""Unit tests for trainer loss calc monkeypatch."""
|
|
|
|
import unittest
|
|
|
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
|
check_evaluation_loop_is_fsdp2_patchable,
|
|
check_evaluation_loop_is_patchable,
|
|
check_maybe_log_save_evaluate_is_patchable,
|
|
)
|
|
|
|
|
|
class TestTrainerLossCalc(unittest.TestCase):
    """
    Test case covering the trainer loss calc monkeypatch.
    """

    def test_trainer_loss_calc_is_patchable(self):
        """
        Verify that the upstream transformers code can still be patched.

        Each patchability check is evaluated and asserted in turn, so the test
        stops at the first check that reports a falsy result. A failure means
        the upstream code targeted by the patch has changed.
        """
        # Evaluate and assert one check at a time so the call order and the
        # fail-fast behaviour stay the same as three consecutive asserts.
        evaluation_loop_ok = check_evaluation_loop_is_patchable()
        assert evaluation_loop_ok

        evaluation_loop_fsdp2_ok = check_evaluation_loop_is_fsdp2_patchable()
        assert evaluation_loop_fsdp2_ok

        maybe_log_save_evaluate_ok = check_maybe_log_save_evaluate_is_patchable()
        assert maybe_log_save_evaluate_ok
|
|
|
|
|
|
# Run this module's tests through the unittest CLI when executed as a script.
if __name__ == "__main__":
    unittest.main()
|