diff --git a/src/axolotl/telemetry/callbacks.py b/src/axolotl/telemetry/callbacks.py
index c788d3174..aff64d146 100644
--- a/src/axolotl/telemetry/callbacks.py
+++ b/src/axolotl/telemetry/callbacks.py
@@ -37,67 +37,63 @@ class TelemetryCallback(TrainerCallback):
         self.last_report_time = None
         self.last_report_step = 0
 
+    # pylint: disable=unused-argument
     def on_train_begin(
         self,
         args: TrainingArguments,
-        state: TrainerState,  # pylint: disable=unused-argument
-        control: TrainerControl,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ):
         """Handle training start."""
         self.telemetry_manager.send_event(event_type="train-start")
 
+    # pylint: disable=unused-argument
     def on_train_end(
         self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
+        args: TrainingArguments,
         state: TrainerState,
-        control: TrainerControl,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
+        control: TrainerControl,
+        **kwargs,
     ):
         """Handle training end."""
         # Send training completion event
         self.telemetry_manager.send_event(
             event_type="train-end",
-            properties={
-                "loss": (
-                    state.log_history[-1].get("loss", 0) if state.log_history else None
-                ),
-                "learning_rate": (
-                    state.log_history[-1].get("learning_rate", 0)
-                    if state.log_history
-                    else None
-                ),
-            }
+            properties=self._extract_last_metrics(state) | self.tracker.metrics.to_dict(),
         )
 
+    # pylint: disable=unused-argument
     def on_epoch_begin(
         self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,  # pylint: disable=unused-argument
-        control: TrainerControl,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ):
         """Handle epoch start."""
         self.current_epoch += 1
         self.tracker.start_epoch(self.current_epoch)
 
+    # pylint: disable=unused-argument
     def on_epoch_end(
         self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,  # pylint: disable=unused-argument
-        control: TrainerControl,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ):
         """Handle epoch end."""
         self.tracker.end_epoch(self.current_epoch)
 
+    # pylint: disable=unused-argument
     def on_step_end(
         self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
+        args: TrainingArguments,
         state: TrainerState,
-        control: TrainerControl,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
+        control: TrainerControl,
+        **kwargs,
     ):
         """Handle step end."""
         step = state.global_step
@@ -118,7 +114,7 @@ class TelemetryCallback(TrainerCallback):
         time_since_last_report = current_time - self.start_time
         steps_since_last_report = step - self.last_report_step
 
-        # Only report if enough time has passed to avoid flooding
+        # Only report if enough time has passed
         if (
             step == 1
             or time_since_last_report >= TIME_SINCE_LAST
@@ -133,20 +129,11 @@ class TelemetryCallback(TrainerCallback):
             # Update memory metrics
             self.tracker.update_memory_metrics()
 
-            loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
-            learning_rate = (
-                state.log_history[-1].get("learning_rate", 0)
-                if state.log_history
-                else 0
-            )
-
             # Prepare metrics to report
-            metrics = {
+            metrics = self._extract_last_metrics(state) | {
                 "step": step,
                 "epoch": self.current_epoch,
                 "progress": state.epoch,  # Fractional epoch progress
-                "loss": loss,
-                "learning_rate": learning_rate,
                 "steps_per_second": steps_per_second,
                 "elapsed_time": current_time - self.start_time,
                 "time_since_last_report": time_since_last_report,
@@ -164,3 +151,14 @@ class TelemetryCallback(TrainerCallback):
             # Update last report time and step
             self.last_report_time = current_time
             self.last_report_step = step
+
+    def _extract_last_metrics(self, state: TrainerState) -> dict:
+        """Extract last loss and learning_rate from log history."""
+        if not state.log_history:
+            return {"loss": 0, "learning_rate": 0}
+
+        last_log = state.log_history[-1]
+        return {
+            "loss": last_log.get("loss", 0),
+            "learning_rate": last_log.get("learning_rate", 0),
+        }
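
Note: the `on_train_end` and `on_step_end` changes lean on PEP 584's dict union operator (`|`, Python 3.9+), where keys from the right operand win on collision. Below is a minimal standalone sketch of the `_extract_last_metrics` pattern and the merge semantics; `FakeState` is a hypothetical stand-in for `transformers.TrainerState` and is not part of this diff.

    # FakeState mimics the one attribute of TrainerState this sketch needs,
    # so it runs without the transformers dependency.
    from dataclasses import dataclass, field


    @dataclass
    class FakeState:
        log_history: list = field(default_factory=list)


    def extract_last_metrics(state: FakeState) -> dict:
        """Mirror of the diff's helper: last logged loss/lr, defaulting to 0."""
        if not state.log_history:
            return {"loss": 0, "learning_rate": 0}
        last_log = state.log_history[-1]
        return {
            "loss": last_log.get("loss", 0),
            "learning_rate": last_log.get("learning_rate", 0),
        }


    state = FakeState(log_history=[{"loss": 0.42, "learning_rate": 2e-5}])
    # PEP 584 dict union: on duplicate keys, the right operand wins.
    metrics = extract_last_metrics(state) | {"step": 100, "loss": 0.0}
    assert metrics == {"loss": 0.0, "learning_rate": 2e-5, "step": 100}

In the diff itself the right-hand dicts add non-overlapping keys, but the ordering means any collision would resolve in their favor.
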
diff --git a/src/axolotl/telemetry/errors.py b/src/axolotl/telemetry/errors.py
index 98acd6a2c..27f2d2192 100644
--- a/src/axolotl/telemetry/errors.py
+++ b/src/axolotl/telemetry/errors.py
@@ -127,7 +127,7 @@ def send_errors(func: Callable) -> Callable:
             return func(*args, **kwargs)
         except Exception as exception:
             # Only track if we're not already handling an error. This prevents us from
-            # capturing an error more than once in nested decorated function calls.=
+            # capturing an error more than once in nested decorated function calls.
             global ERROR_HANDLED  # pylint: disable=global-statement
             if not ERROR_HANDLED:
                 ERROR_HANDLED = True
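
Note: the comment fix above sits inside `send_errors`, whose module-level `ERROR_HANDLED` flag deduplicates reports when decorated functions nest. A minimal re-entrancy-guard sketch of that pattern, under the assumption that the decorator reports and then re-raises; the `print` call is a hypothetical stand-in for the real telemetry send, not the module's actual API.

    from functools import wraps
    from typing import Callable

    ERROR_HANDLED = False


    def send_errors(func: Callable) -> Callable:
        """Report an exception once, however deeply decorated calls nest."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            global ERROR_HANDLED
            try:
                return func(*args, **kwargs)
            except Exception as exception:
                if not ERROR_HANDLED:
                    ERROR_HANDLED = True
                    # Hypothetical stand-in for the real telemetry send.
                    print(f"reported once: {exception!r}")
                raise

        return wrapper


    @send_errors
    def inner():
        raise ValueError("boom")


    @send_errors
    def outer():
        inner()


    try:
        outer()
    except ValueError:
        pass  # inner()'s wrapper reported; outer's saw ERROR_HANDLED and stayed quiet
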
diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py
index b8aa5c2ad..dde97dad9 100644
--- a/src/axolotl/telemetry/manager.py
+++ b/src/axolotl/telemetry/manager.py
@@ -172,8 +172,7 @@ class TelemetryManager:
         https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
 
         Returns:
-            Tuple containing:
-                - Boolean denoting whether telemetry is enabled or not.
+            Boolean denoting whether telemetry is enabled or not.
         """
         # Parse relevant env vars
         axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
@@ -233,7 +232,10 @@ class TelemetryManager:
         """
         # NOTE: This membership-checking logic can be improved.
         # What happens when a local model path matches a whitelisted org?
-        org = value.split("/")[0]
+        parts = value.split("/")
+        if len(parts) < 2:
+            return False
+        org = parts[0]
         whitelisted = org.lower() in self.whitelist["organizations"]
 
         return whitelisted
@@ -406,7 +408,8 @@ class TelemetryManager:
 
     def send_system_info(self):
         """Helper method for sending system info"""
-        self.send_event(event_type="system-info", properties=self.system_info)
+        if self.system_info is not None:
+            self.send_event(event_type="system-info", properties=self.system_info)
 
     def shutdown(self):
         """Ensure all queued events are processed before shutdown"""
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index d341e3c77..f4f4a5a91 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -50,6 +50,7 @@ except ImportError:
 
 LOG = get_logger(__name__)
 TELEMETRY_MANAGER = TelemetryManager.get_instance()
+PLUGIN_MANAGER = PluginManager.get_instance()
 
 
 def setup_model_and_tokenizer(
@@ -532,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
         model_ref=model_ref,
         peft_config=peft_config,
     )
+    PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
 
     return (
         trainer,
@@ -567,17 +569,6 @@ def train(
         processor,
     ) = setup_model_and_trainer(cfg, dataset_meta)
 
-    TELEMETRY_MANAGER.send_event(
-        event_type="model-load", properties=model.config.to_dict()
-    )
-    if peft_config:
-        TELEMETRY_MANAGER.send_event(
-            event_type="peft-config-load", properties=peft_config.to_dict()
-        )
-
-    plugin_manager = PluginManager.get_instance()
-    plugin_manager.post_trainer_create(cfg, trainer)
-
     # Determine if we need to resume from a checkpoint
     resume_from_checkpoint = determine_resume_checkpoint(cfg)
 
@@ -604,7 +595,6 @@ def train(
     create_model_card(cfg, trainer)
    if not cfg.use_ray:
         cleanup_distributed()
-
-    plugin_manager.post_train(cfg, model)
+    PLUGIN_MANAGER.post_train(cfg, model)
 
     return model, tokenizer, trainer
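
Note: the `is_whitelisted` hunk above guards against bare values that contain no `/` (e.g. a local model path), which would previously have been compared against the org whitelist wholesale. A standalone sketch of the guarded check; the whitelist contents here are made up for illustration.

    def is_whitelisted(value: str, organizations: set[str]) -> bool:
        """Only org-qualified model IDs can match the whitelist."""
        parts = value.split("/")
        if len(parts) < 2:
            # Bare value: no org component to check.
            return False
        return parts[0].lower() in organizations


    orgs = {"meta-llama", "mistralai"}  # illustrative, not the real whitelist
    assert is_whitelisted("meta-llama/Llama-3-8B", orgs)
    assert not is_whitelisted("my-local-model", orgs)  # no "/": guarded out
    assert not is_whitelisted("someorg/some-model", orgs)
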