feat: add trackio as experiment tracking integration (#3253)

* feat: add trackio as experiment tracking integration

- Add TrackioConfig to integrations schema with project_name, run_name, and space_id
- Create trackio_.py module for environment setup
- Add is_trackio_available() utility function
- Integrate trackio with report_to in trainer builder
- Add trackio callback for experiment tracking
- Add trackio config keys to gpt-oss example YAMLs
- Trackio runs locally by default, syncs to HF Space if space_id provided

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* Update requirements.txt

* don't allow pydantic 2.12 for now

---------

Co-authored-by: Abubakar Abid <aaabid93@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Abubakar Abid
2025-12-23 05:49:07 -08:00
committed by GitHub
parent 92ee4256f7
commit f2155eaf79
16 changed files with 134 additions and 7 deletions

View File

@@ -32,6 +32,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -28,6 +28,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -29,6 +29,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -28,6 +28,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -41,6 +41,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

View File

@@ -41,6 +41,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

View File

@@ -20,15 +20,16 @@ deepspeed>=0.17.0
trl==0.25.0
hf_xet==1.2.0
kernels>=0.9.0
trackio
trackio>=0.13.0
typing_extensions>=4.14.0
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.49.1
gradio>=6.2.0,<7.0
modal==1.0.2
pydantic>=2.10.6
pydantic>=2.10.6,<2.12
addict
fire
PyYAML>=6.0

View File

@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trackio_ import setup_trackio_env_vars
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -246,6 +247,7 @@ def load_cfg(
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
setup_trackio_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)

View File

@@ -288,8 +288,8 @@ def do_inference_gradio(
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
demo.launch(
footer_links=["gradio", "settings"],
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),

View File

@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
outputs=[masked_preview, html_out],
)
demo.queue().launch(
show_api=False,
demo.launch(
footer_links=["gradio", "settings"],
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),

View File

@@ -35,6 +35,7 @@ from axolotl.utils import (
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
is_trackio_available,
)
from axolotl.utils.callbacks import (
GCCallback,
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_trackio and is_trackio_available():
from axolotl.utils.callbacks.trackio_ import (
SaveAxolotlConfigtoTrackioCallback,
)
callbacks.append(
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
@@ -434,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
if self.cfg.use_trackio:
report_to.append("trackio")
training_args_kwargs["report_to"] = report_to
@@ -441,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
elif self.cfg.use_trackio:
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
else:
training_args_kwargs["run_name"] = None

View File

@@ -24,6 +24,10 @@ def is_opentelemetry_available():
)
def is_trackio_available():
return importlib.util.find_spec("trackio") is not None
def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).

View File

@@ -0,0 +1,44 @@
"""Trackio module for trainer callbacks"""
from typing import TYPE_CHECKING
import trackio
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
from axolotl.utils.environment import is_package_version_ge
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = get_logger(__name__)
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
"""Callback for trackio integration"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments",
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if is_main_process():
try:
if not is_package_version_ge("trackio", "0.11.0"):
LOG.warning(
"Trackio version 0.11.0 or higher is required to save config files. "
"Please upgrade trackio: pip install --upgrade trackio"
)
return control
trackio.save(self.axolotl_config_path)
LOG.info("The Axolotl config has been saved to Trackio.")
except (FileNotFoundError, ConnectionError, AttributeError) as err:
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
return control

View File

@@ -34,6 +34,7 @@ from axolotl.utils.schemas.integrations import (
MLFlowConfig,
OpenTelemetryConfig,
RayConfig,
TrackioConfig,
WandbConfig,
)
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
@@ -63,6 +64,7 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
CometConfig,
TrackioConfig,
OpenTelemetryConfig,
LISAConfig,
GradioConfig,

View File

@@ -200,3 +200,23 @@ class OpenTelemetryConfig(BaseModel):
"description": "Port for the Prometheus metrics HTTP server"
},
)
class TrackioConfig(BaseModel):
"""Trackio configuration subset"""
use_trackio: bool | None = None
trackio_project_name: str | None = Field(
default=None,
json_schema_extra={"description": "Your trackio project name"},
)
trackio_run_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your trackio run"},
)
trackio_space_id: str | None = Field(
default=None,
json_schema_extra={
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
},
)

View File

@@ -0,0 +1,17 @@
"""Module for trackio utilities"""
import os
from axolotl.utils.dict import DictDefault
def setup_trackio_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("trackio_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:
cfg.use_trackio = True