diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py index 3a99d0115..7704ceddc 100644 --- a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -37,7 +37,10 @@ PATCHED_MAYBE_CODE = ( def check_evaluation_loop_is_patchable() -> bool: - evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + if hasattr(Trainer, "_original_evaluation_loop"): + evaluation_loop_source = Trainer._original_evaluation_loop + else: + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values()) @@ -53,7 +56,7 @@ def patch_evaluation_loop(): evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) except OSError: return - Trainer.evaluation = evaluation_loop_source + Trainer._original_evaluation_loop = evaluation_loop_source evaluation_loop_source, _ = detab_code(evaluation_loop_source) # Apply the nanmean patches @@ -93,7 +96,10 @@ def patch_evaluation_loop(): def check_maybe_log_save_evaluate_is_patchable() -> bool: - maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + if hasattr(Trainer, "_original_maybe_log_save_evaluate"): + maybe_log_source = Trainer._original_maybe_log_save_evaluate + else: + maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) return ORIGINAL_MAYBE_CODE in maybe_log_source