feat: add trackio as experiment tracking integration (#3253)
* feat: add trackio as experiment tracking integration - Add TrackioConfig to integrations schema with project_name, run_name, and space_id - Create trackio_.py module for environment setup - Add is_trackio_available() utility function - Integrate trackio with report_to in trainer builder - Add trackio callback for experiment tracking - Add trackio config keys to gpt-oss example YAMLs - Trackio runs locally by default, syncs to HF Space if space_id provided * changes * changes * changes * changes * changes * changes * changes * Update requirements.txt * don't allow pydantic 2.12 for now --------- Co-authored-by: Abubakar Abid <aaabid93@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -32,6 +32,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,6 +28,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -29,6 +29,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,6 +28,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,6 +41,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,6 +41,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -20,15 +20,16 @@ deepspeed>=0.17.0
|
||||
trl==0.25.0
|
||||
hf_xet==1.2.0
|
||||
kernels>=0.9.0
|
||||
trackio
|
||||
trackio>=0.13.0
|
||||
typing_extensions>=4.14.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==5.49.1
|
||||
gradio>=6.2.0,<7.0
|
||||
|
||||
modal==1.0.2
|
||||
pydantic>=2.10.6
|
||||
pydantic>=2.10.6,<2.12
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
|
||||
@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
@@ -246,6 +247,7 @@ def load_cfg(
|
||||
setup_wandb_env_vars(cfg)
|
||||
setup_mlflow_env_vars(cfg)
|
||||
setup_comet_env_vars(cfg)
|
||||
setup_trackio_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
|
||||
@@ -288,8 +288,8 @@ def do_inference_gradio(
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
||||
outputs=[masked_preview, html_out],
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -35,6 +35,7 @@ from axolotl.utils import (
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_opentelemetry_available,
|
||||
is_trackio_available,
|
||||
)
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_trackio and is_trackio_available():
|
||||
from axolotl.utils.callbacks.trackio_ import (
|
||||
SaveAxolotlConfigtoTrackioCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OpenTelemetryMetricsCallback,
|
||||
@@ -434,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
if self.cfg.use_trackio:
|
||||
report_to.append("trackio")
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
@@ -441,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
elif self.cfg.use_trackio:
|
||||
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ def is_opentelemetry_available():
|
||||
)
|
||||
|
||||
|
||||
def is_trackio_available():
|
||||
return importlib.util.find_spec("trackio") is not None
|
||||
|
||||
|
||||
def get_pytorch_version() -> tuple[int, int, int]:
|
||||
"""
|
||||
Get Pytorch version as a tuple of (major, minor, patch).
|
||||
|
||||
44
src/axolotl/utils/callbacks/trackio_.py
Normal file
44
src/axolotl/utils/callbacks/trackio_.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Trackio module for trainer callbacks"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trackio
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.environment import is_package_version_ge
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
|
||||
"""Callback for trackio integration"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments",
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
if not is_package_version_ge("trackio", "0.11.0"):
|
||||
LOG.warning(
|
||||
"Trackio version 0.11.0 or higher is required to save config files. "
|
||||
"Please upgrade trackio: pip install --upgrade trackio"
|
||||
)
|
||||
return control
|
||||
|
||||
trackio.save(self.axolotl_config_path)
|
||||
LOG.info("The Axolotl config has been saved to Trackio.")
|
||||
except (FileNotFoundError, ConnectionError, AttributeError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
|
||||
return control
|
||||
@@ -34,6 +34,7 @@ from axolotl.utils.schemas.integrations import (
|
||||
MLFlowConfig,
|
||||
OpenTelemetryConfig,
|
||||
RayConfig,
|
||||
TrackioConfig,
|
||||
WandbConfig,
|
||||
)
|
||||
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
|
||||
@@ -63,6 +64,7 @@ class AxolotlInputConfig(
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
CometConfig,
|
||||
TrackioConfig,
|
||||
OpenTelemetryConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
|
||||
@@ -200,3 +200,23 @@ class OpenTelemetryConfig(BaseModel):
|
||||
"description": "Port for the Prometheus metrics HTTP server"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TrackioConfig(BaseModel):
|
||||
"""Trackio configuration subset"""
|
||||
|
||||
use_trackio: bool | None = None
|
||||
trackio_project_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Your trackio project name"},
|
||||
)
|
||||
trackio_run_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Set the name of your trackio run"},
|
||||
)
|
||||
trackio_space_id: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
|
||||
},
|
||||
)
|
||||
|
||||
17
src/axolotl/utils/trackio_.py
Normal file
17
src/axolotl/utils/trackio_.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Module for trackio utilities"""
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def setup_trackio_env_vars(cfg: DictDefault):
|
||||
for key in cfg.keys():
|
||||
if key.startswith("trackio_"):
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value and isinstance(value, str) and len(value) > 0:
|
||||
os.environ[key.upper()] = value
|
||||
|
||||
if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:
|
||||
cfg.use_trackio = True
|
||||
Reference in New Issue
Block a user