coderabbit comments

This commit is contained in:
Dan Saunders
2025-06-07 04:50:29 +00:00
parent 657bffd85f
commit 345a159796
4 changed files with 47 additions and 56 deletions

View File

@@ -37,67 +37,63 @@ class TelemetryCallback(TrainerCallback):
self.last_report_time = None self.last_report_time = None
self.last_report_step = 0 self.last_report_step = 0
# pylint: disable=unused-argument
def on_train_begin( def on_train_begin(
self, self,
args: TrainingArguments, args: TrainingArguments,
state: TrainerState, # pylint: disable=unused-argument state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl,
**kwargs, # pylint: disable=unused-argument **kwargs,
): ):
"""Handle training start.""" """Handle training start."""
self.telemetry_manager.send_event(event_type="train-start") self.telemetry_manager.send_event(event_type="train-start")
# pylint: disable=unused-argument
def on_train_end( def on_train_end(
self, self,
args: TrainingArguments, # pylint: disable=unused-argument args: TrainingArguments,
state: TrainerState, state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl,
**kwargs, # pylint: disable=unused-argument **kwargs,
): ):
"""Handle training end.""" """Handle training end."""
# Send training completion event # Send training completion event
self.telemetry_manager.send_event( self.telemetry_manager.send_event(
event_type="train-end", event_type="train-end",
properties={ properties=self._extract_last_metrics(state)
"loss": (
state.log_history[-1].get("loss", 0) if state.log_history else None
),
"learning_rate": (
state.log_history[-1].get("learning_rate", 0)
if state.log_history
else None
),
}
| self.tracker.metrics.to_dict(), | self.tracker.metrics.to_dict(),
) )
# pylint: disable=unused-argument
def on_epoch_begin( def on_epoch_begin(
self, self,
args: TrainingArguments, # pylint: disable=unused-argument args: TrainingArguments,
state: TrainerState, # pylint: disable=unused-argument state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl,
**kwargs, # pylint: disable=unused-argument **kwargs,
): ):
"""Handle epoch start.""" """Handle epoch start."""
self.current_epoch += 1 self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch) self.tracker.start_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_epoch_end( def on_epoch_end(
self, self,
args: TrainingArguments, # pylint: disable=unused-argument args: TrainingArguments,
state: TrainerState, # pylint: disable=unused-argument state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl,
**kwargs, # pylint: disable=unused-argument **kwargs,
): ):
"""Handle epoch end.""" """Handle epoch end."""
self.tracker.end_epoch(self.current_epoch) self.tracker.end_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_step_end( def on_step_end(
self, self,
args: TrainingArguments, # pylint: disable=unused-argument args: TrainingArguments,
state: TrainerState, state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl,
**kwargs, # pylint: disable=unused-argument **kwargs,
): ):
"""Handle step end.""" """Handle step end."""
step = state.global_step step = state.global_step
@@ -118,7 +114,7 @@ class TelemetryCallback(TrainerCallback):
time_since_last_report = current_time - self.start_time time_since_last_report = current_time - self.start_time
steps_since_last_report = step - self.last_report_step steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed to avoid flooding # Only report if enough time has passed
if ( if (
step == 1 step == 1
or time_since_last_report >= TIME_SINCE_LAST or time_since_last_report >= TIME_SINCE_LAST
@@ -133,20 +129,11 @@ class TelemetryCallback(TrainerCallback):
# Update memory metrics # Update memory metrics
self.tracker.update_memory_metrics() self.tracker.update_memory_metrics()
loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
learning_rate = (
state.log_history[-1].get("learning_rate", 0)
if state.log_history
else 0
)
# Prepare metrics to report # Prepare metrics to report
metrics = { metrics = self._extract_last_metrics(state) | {
"step": step, "step": step,
"epoch": self.current_epoch, "epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress "progress": state.epoch, # Fractional epoch progress
"loss": loss,
"learning_rate": learning_rate,
"steps_per_second": steps_per_second, "steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time, "elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report, "time_since_last_report": time_since_last_report,
@@ -164,3 +151,14 @@ class TelemetryCallback(TrainerCallback):
# Update last report time and step # Update last report time and step
self.last_report_time = current_time self.last_report_time = current_time
self.last_report_step = step self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss and learning_rate from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"learning_rate": last_log.get("learning_rate", 0),
}

View File

@@ -127,7 +127,7 @@ def send_errors(func: Callable) -> Callable:
return func(*args, **kwargs) return func(*args, **kwargs)
except Exception as exception: except Exception as exception:
# Only track if we're not already handling an error. This prevents us from # Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.= # capturing an error more than once in nested decorated function calls.
global ERROR_HANDLED # pylint: disable=global-statement global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED: if not ERROR_HANDLED:
ERROR_HANDLED = True ERROR_HANDLED = True

View File

@@ -172,8 +172,7 @@ class TelemetryManager:
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html. https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns: Returns:
Tuple containing: Boolean denoting whether telemetry is enabled or not.
- Boolean denoting whether telemetry is enabled or not.
""" """
# Parse relevant env vars # Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK") axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
@@ -233,7 +232,10 @@ class TelemetryManager:
""" """
# NOTE: This membership-checking logic can be improved. # NOTE: This membership-checking logic can be improved.
# What happens when a local model path matches a whitelisted org? # What happens when a local model path matches a whitelisted org?
org = value.split("/")[0] parts = value.split("/")
if len(parts) < 2:
return False
org = parts[0]
whitelisted = org.lower() in self.whitelist["organizations"] whitelisted = org.lower() in self.whitelist["organizations"]
return whitelisted return whitelisted
@@ -406,7 +408,8 @@ class TelemetryManager:
def send_system_info(self): def send_system_info(self):
"""Helper method for sending system info""" """Helper method for sending system info"""
self.send_event(event_type="system-info", properties=self.system_info) if self.system_info is not None:
self.send_event(event_type="system-info", properties=self.system_info)
def shutdown(self): def shutdown(self):
"""Ensure all queued events are processed before shutdown""" """Ensure all queued events are processed before shutdown"""

View File

@@ -50,6 +50,7 @@ except ImportError:
LOG = get_logger(__name__) LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance() TELEMETRY_MANAGER = TelemetryManager.get_instance()
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_model_and_tokenizer( def setup_model_and_tokenizer(
@@ -532,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
model_ref=model_ref, model_ref=model_ref,
peft_config=peft_config, peft_config=peft_config,
) )
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
return ( return (
trainer, trainer,
@@ -567,17 +569,6 @@ def train(
processor, processor,
) = setup_model_and_trainer(cfg, dataset_meta) ) = setup_model_and_trainer(cfg, dataset_meta)
TELEMETRY_MANAGER.send_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Determine if we need to resume from a checkpoint # Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg) resume_from_checkpoint = determine_resume_checkpoint(cfg)
@@ -604,7 +595,6 @@ def train(
create_model_card(cfg, trainer) create_model_card(cfg, trainer)
if not cfg.use_ray: if not cfg.use_ray:
cleanup_distributed() cleanup_distributed()
PLUGIN_MANAGER.post_train(cfg, model)
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer return model, tokenizer, trainer