coderabbit comments
This commit is contained in:
@@ -37,67 +37,63 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
self.last_report_time = None
|
self.last_report_time = None
|
||||||
self.last_report_step = 0
|
self.last_report_step = 0
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def on_train_begin(
|
def on_train_begin(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments,
|
args: TrainingArguments,
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
state: TrainerState,
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Handle training start."""
|
"""Handle training start."""
|
||||||
self.telemetry_manager.send_event(event_type="train-start")
|
self.telemetry_manager.send_event(event_type="train-start")
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def on_train_end(
|
def on_train_end(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
args: TrainingArguments,
|
||||||
state: TrainerState,
|
state: TrainerState,
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Handle training end."""
|
"""Handle training end."""
|
||||||
# Send training completion event
|
# Send training completion event
|
||||||
self.telemetry_manager.send_event(
|
self.telemetry_manager.send_event(
|
||||||
event_type="train-end",
|
event_type="train-end",
|
||||||
properties={
|
properties=self._extract_last_metrics(state)
|
||||||
"loss": (
|
|
||||||
state.log_history[-1].get("loss", 0) if state.log_history else None
|
|
||||||
),
|
|
||||||
"learning_rate": (
|
|
||||||
state.log_history[-1].get("learning_rate", 0)
|
|
||||||
if state.log_history
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
|
||||||
| self.tracker.metrics.to_dict(),
|
| self.tracker.metrics.to_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def on_epoch_begin(
|
def on_epoch_begin(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
args: TrainingArguments,
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
state: TrainerState,
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Handle epoch start."""
|
"""Handle epoch start."""
|
||||||
self.current_epoch += 1
|
self.current_epoch += 1
|
||||||
self.tracker.start_epoch(self.current_epoch)
|
self.tracker.start_epoch(self.current_epoch)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def on_epoch_end(
|
def on_epoch_end(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
args: TrainingArguments,
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
state: TrainerState,
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Handle epoch end."""
|
"""Handle epoch end."""
|
||||||
self.tracker.end_epoch(self.current_epoch)
|
self.tracker.end_epoch(self.current_epoch)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def on_step_end(
|
def on_step_end(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
args: TrainingArguments,
|
||||||
state: TrainerState,
|
state: TrainerState,
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Handle step end."""
|
"""Handle step end."""
|
||||||
step = state.global_step
|
step = state.global_step
|
||||||
@@ -118,7 +114,7 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
time_since_last_report = current_time - self.start_time
|
time_since_last_report = current_time - self.start_time
|
||||||
steps_since_last_report = step - self.last_report_step
|
steps_since_last_report = step - self.last_report_step
|
||||||
|
|
||||||
# Only report if enough time has passed to avoid flooding
|
# Only report if enough time has passed
|
||||||
if (
|
if (
|
||||||
step == 1
|
step == 1
|
||||||
or time_since_last_report >= TIME_SINCE_LAST
|
or time_since_last_report >= TIME_SINCE_LAST
|
||||||
@@ -133,20 +129,11 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
# Update memory metrics
|
# Update memory metrics
|
||||||
self.tracker.update_memory_metrics()
|
self.tracker.update_memory_metrics()
|
||||||
|
|
||||||
loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
|
|
||||||
learning_rate = (
|
|
||||||
state.log_history[-1].get("learning_rate", 0)
|
|
||||||
if state.log_history
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare metrics to report
|
# Prepare metrics to report
|
||||||
metrics = {
|
metrics = self._extract_last_metrics(state) | {
|
||||||
"step": step,
|
"step": step,
|
||||||
"epoch": self.current_epoch,
|
"epoch": self.current_epoch,
|
||||||
"progress": state.epoch, # Fractional epoch progress
|
"progress": state.epoch, # Fractional epoch progress
|
||||||
"loss": loss,
|
|
||||||
"learning_rate": learning_rate,
|
|
||||||
"steps_per_second": steps_per_second,
|
"steps_per_second": steps_per_second,
|
||||||
"elapsed_time": current_time - self.start_time,
|
"elapsed_time": current_time - self.start_time,
|
||||||
"time_since_last_report": time_since_last_report,
|
"time_since_last_report": time_since_last_report,
|
||||||
@@ -164,3 +151,14 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
# Update last report time and step
|
# Update last report time and step
|
||||||
self.last_report_time = current_time
|
self.last_report_time = current_time
|
||||||
self.last_report_step = step
|
self.last_report_step = step
|
||||||
|
|
||||||
|
def _extract_last_metrics(self, state: TrainerState) -> dict:
|
||||||
|
"""Extract last loss and learning_rate from log history."""
|
||||||
|
if not state.log_history:
|
||||||
|
return {"loss": 0, "learning_rate": 0}
|
||||||
|
|
||||||
|
last_log = state.log_history[-1]
|
||||||
|
return {
|
||||||
|
"loss": last_log.get("loss", 0),
|
||||||
|
"learning_rate": last_log.get("learning_rate", 0),
|
||||||
|
}
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def send_errors(func: Callable) -> Callable:
|
|||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
# Only track if we're not already handling an error. This prevents us from
|
# Only track if we're not already handling an error. This prevents us from
|
||||||
# capturing an error more than once in nested decorated function calls.=
|
# capturing an error more than once in nested decorated function calls.
|
||||||
global ERROR_HANDLED # pylint: disable=global-statement
|
global ERROR_HANDLED # pylint: disable=global-statement
|
||||||
if not ERROR_HANDLED:
|
if not ERROR_HANDLED:
|
||||||
ERROR_HANDLED = True
|
ERROR_HANDLED = True
|
||||||
|
|||||||
@@ -172,8 +172,7 @@ class TelemetryManager:
|
|||||||
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Boolean denoting whether telemetry is enabled or not.
|
||||||
- Boolean denoting whether telemetry is enabled or not.
|
|
||||||
"""
|
"""
|
||||||
# Parse relevant env vars
|
# Parse relevant env vars
|
||||||
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
||||||
@@ -233,7 +232,10 @@ class TelemetryManager:
|
|||||||
"""
|
"""
|
||||||
# NOTE: This membership-checking logic can be improved.
|
# NOTE: This membership-checking logic can be improved.
|
||||||
# What happens when a local model path matches a whitelisted org?
|
# What happens when a local model path matches a whitelisted org?
|
||||||
org = value.split("/")[0]
|
parts = value.split("/")
|
||||||
|
if len(parts) < 2:
|
||||||
|
return False
|
||||||
|
org = parts[0]
|
||||||
whitelisted = org.lower() in self.whitelist["organizations"]
|
whitelisted = org.lower() in self.whitelist["organizations"]
|
||||||
|
|
||||||
return whitelisted
|
return whitelisted
|
||||||
@@ -406,7 +408,8 @@ class TelemetryManager:
|
|||||||
|
|
||||||
def send_system_info(self):
|
def send_system_info(self):
|
||||||
"""Helper method for sending system info"""
|
"""Helper method for sending system info"""
|
||||||
self.send_event(event_type="system-info", properties=self.system_info)
|
if self.system_info is not None:
|
||||||
|
self.send_event(event_type="system-info", properties=self.system_info)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Ensure all queued events are processed before shutdown"""
|
"""Ensure all queued events are processed before shutdown"""
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ except ImportError:
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
|
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_tokenizer(
|
def setup_model_and_tokenizer(
|
||||||
@@ -532,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
model_ref=model_ref,
|
model_ref=model_ref,
|
||||||
peft_config=peft_config,
|
peft_config=peft_config,
|
||||||
)
|
)
|
||||||
|
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
trainer,
|
trainer,
|
||||||
@@ -567,17 +569,6 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
TELEMETRY_MANAGER.send_event(
|
|
||||||
event_type="model-load", properties=model.config.to_dict()
|
|
||||||
)
|
|
||||||
if peft_config:
|
|
||||||
TELEMETRY_MANAGER.send_event(
|
|
||||||
event_type="peft-config-load", properties=peft_config.to_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
plugin_manager.post_trainer_create(cfg, trainer)
|
|
||||||
|
|
||||||
# Determine if we need to resume from a checkpoint
|
# Determine if we need to resume from a checkpoint
|
||||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||||
|
|
||||||
@@ -604,7 +595,6 @@ def train(
|
|||||||
create_model_card(cfg, trainer)
|
create_model_card(cfg, trainer)
|
||||||
if not cfg.use_ray:
|
if not cfg.use_ray:
|
||||||
cleanup_distributed()
|
cleanup_distributed()
|
||||||
|
PLUGIN_MANAGER.post_train(cfg, model)
|
||||||
plugin_manager.post_train(cfg, model)
|
|
||||||
|
|
||||||
return model, tokenizer, trainer
|
return model, tokenizer, trainer
|
||||||
|
|||||||
Reference in New Issue
Block a user