fix: robust handling of race condition on patching check (#3543) [skip ci]
This commit is contained in:
@@ -37,7 +37,10 @@ PATCHED_MAYBE_CODE = (
|
||||
|
||||
|
||||
def check_evaluation_loop_is_patchable() -> bool:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||
evaluation_loop_source = Trainer._original_evaluation_loop
|
||||
else:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
||||
|
||||
|
||||
@@ -53,7 +56,7 @@ def patch_evaluation_loop():
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
except OSError:
|
||||
return
|
||||
Trainer.evaluation = evaluation_loop_source
|
||||
Trainer._original_evaluation_loop = evaluation_loop_source
|
||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||
|
||||
# Apply the nanmean patches
|
||||
@@ -93,7 +96,10 @@ def patch_evaluation_loop():
|
||||
|
||||
|
||||
def check_maybe_log_save_evaluate_is_patchable() -> bool:
|
||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
|
||||
maybe_log_source = Trainer._original_maybe_log_save_evaluate
|
||||
else:
|
||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||
return ORIGINAL_MAYBE_CODE in maybe_log_source
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user