tests for runtime metrics telemetry and assoc. callback
This commit is contained in:
@@ -1,8 +0,0 @@
|
||||
"""Init for axolotl.telemetry module."""
|
||||
|
||||
from .manager import TelemetryConfig, TelemetryManager
|
||||
|
||||
__all__ = [
|
||||
"TelemetryConfig",
|
||||
"TelemetryManager",
|
||||
]
|
||||
@@ -34,7 +34,7 @@ class TelemetryCallback(TrainerCallback):
|
||||
self.telemetry_manager = TelemetryManager.get_instance()
|
||||
self.current_epoch = -1
|
||||
self.start_time = time.time()
|
||||
self.last_report_time = self.start_time
|
||||
self.last_report_time = None
|
||||
self.last_report_step = 0
|
||||
|
||||
def on_train_begin(
|
||||
@@ -110,12 +110,16 @@ class TelemetryCallback(TrainerCallback):
|
||||
|
||||
if should_report:
|
||||
current_time = time.time()
|
||||
time_since_last_report = current_time - self.last_report_time
|
||||
if self.last_report_time is not None:
|
||||
time_since_last_report = current_time - self.last_report_time
|
||||
else:
|
||||
time_since_last_report = current_time - self.start_time
|
||||
steps_since_last_report = step - self.last_report_step
|
||||
|
||||
# Only report if enough time has passed to avoid flooding
|
||||
if (
|
||||
time_since_last_report >= TIME_SINCE_LAST
|
||||
step == 1
|
||||
or time_since_last_report >= TIME_SINCE_LAST
|
||||
or steps_since_last_report >= self.report_interval_steps
|
||||
):
|
||||
# Calculate steps per second for this interval
|
||||
|
||||
Reference in New Issue
Block a user