coderabbit comments

This commit is contained in:
Dan Saunders
2025-06-07 04:50:29 +00:00
parent 657bffd85f
commit 345a159796
4 changed files with 47 additions and 56 deletions

View File

@@ -37,67 +37,63 @@ class TelemetryCallback(TrainerCallback):
self.last_report_time = None
self.last_report_step = 0
# pylint: disable=unused-argument
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training start."""
self.telemetry_manager.send_event(event_type="train-start")
# pylint: disable=unused-argument
def on_train_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
args: TrainingArguments,
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs,
):
"""Handle training end."""
# Send training completion event
self.telemetry_manager.send_event(
event_type="train-end",
properties={
"loss": (
state.log_history[-1].get("loss", 0) if state.log_history else None
),
"learning_rate": (
state.log_history[-1].get("learning_rate", 0)
if state.log_history
else None
),
}
properties=self._extract_last_metrics(state)
| self.tracker.metrics.to_dict(),
)
# pylint: disable=unused-argument
def on_epoch_begin(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch start."""
self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_epoch_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch end."""
self.tracker.end_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
args: TrainingArguments,
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs,
):
"""Handle step end."""
step = state.global_step
@@ -118,7 +114,7 @@ class TelemetryCallback(TrainerCallback):
time_since_last_report = current_time - self.start_time
steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed to avoid flooding
# Only report if enough time has passed
if (
step == 1
or time_since_last_report >= TIME_SINCE_LAST
@@ -133,20 +129,11 @@ class TelemetryCallback(TrainerCallback):
# Update memory metrics
self.tracker.update_memory_metrics()
loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
learning_rate = (
state.log_history[-1].get("learning_rate", 0)
if state.log_history
else 0
)
# Prepare metrics to report
metrics = {
metrics = self._extract_last_metrics(state) | {
"step": step,
"epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress
"loss": loss,
"learning_rate": learning_rate,
"steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report,
@@ -164,3 +151,14 @@ class TelemetryCallback(TrainerCallback):
# Update last report time and step
self.last_report_time = current_time
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss and learning_rate from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"learning_rate": last_log.get("learning_rate", 0),
}

View File

@@ -127,7 +127,7 @@ def send_errors(func: Callable) -> Callable:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.=
# capturing an error more than once in nested decorated function calls.
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True

View File

@@ -172,8 +172,7 @@ class TelemetryManager:
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns:
Tuple containing:
- Boolean denoting whether telemetry is enabled or not.
Boolean denoting whether telemetry is enabled or not.
"""
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
@@ -233,7 +232,10 @@ class TelemetryManager:
"""
# NOTE: This membership-checking logic can be improved.
# What happens when a local model path matches a whitelisted org?
org = value.split("/")[0]
parts = value.split("/")
if len(parts) < 2:
return False
org = parts[0]
whitelisted = org.lower() in self.whitelist["organizations"]
return whitelisted
@@ -406,7 +408,8 @@ class TelemetryManager:
def send_system_info(self):
"""Helper method for sending system info"""
self.send_event(event_type="system-info", properties=self.system_info)
if self.system_info is not None:
self.send_event(event_type="system-info", properties=self.system_info)
def shutdown(self):
"""Ensure all queued events are processed before shutdown"""

View File

@@ -50,6 +50,7 @@ except ImportError:
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_model_and_tokenizer(
@@ -532,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
model_ref=model_ref,
peft_config=peft_config,
)
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
return (
trainer,
@@ -567,17 +569,6 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
TELEMETRY_MANAGER.send_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
@@ -604,7 +595,6 @@ def train(
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
plugin_manager.post_train(cfg, model)
PLUGIN_MANAGER.post_train(cfg, model)
return model, tokenizer, trainer