feat: add trackio as experiment tracking integration (#3253)
* feat: add trackio as experiment tracking integration - Add TrackioConfig to integrations schema with project_name, run_name, and space_id - Create trackio_.py module for environment setup - Add is_trackio_available() utility function - Integrate trackio with report_to in trainer builder - Add trackio callback for experiment tracking - Add trackio config keys to gpt-oss example YAMLs - Trackio runs locally by default, syncs to HF Space if space_id provided * changes * changes * changes * changes * changes * changes * changes * Update requirements.txt * don't allow pydantic 2.12 for now --------- Co-authored-by: Abubakar Abid <aaabid93@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -32,6 +32,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -20,15 +20,16 @@ deepspeed>=0.17.0
|
|||||||
trl==0.25.0
|
trl==0.25.0
|
||||||
hf_xet==1.2.0
|
hf_xet==1.2.0
|
||||||
kernels>=0.9.0
|
kernels>=0.9.0
|
||||||
trackio
|
trackio>=0.13.0
|
||||||
|
typing_extensions>=4.14.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.49.1
|
gradio>=6.2.0,<7.0
|
||||||
|
|
||||||
modal==1.0.2
|
modal==1.0.2
|
||||||
pydantic>=2.10.6
|
pydantic>=2.10.6,<2.12
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.tee import prepare_debug_log
|
from axolotl.utils.tee import prepare_debug_log
|
||||||
|
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -246,6 +247,7 @@ def load_cfg(
|
|||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
|
setup_trackio_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||||
|
|||||||
@@ -288,8 +288,8 @@ def do_inference_gradio(
|
|||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.queue().launch(
|
demo.launch(
|
||||||
show_api=False,
|
footer_links=["gradio", "settings"],
|
||||||
share=cfg.get("gradio_share", True),
|
share=cfg.get("gradio_share", True),
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
|||||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
|||||||
outputs=[masked_preview, html_out],
|
outputs=[masked_preview, html_out],
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.queue().launch(
|
demo.launch(
|
||||||
show_api=False,
|
footer_links=["gradio", "settings"],
|
||||||
share=cfg.get("gradio_share", True),
|
share=cfg.get("gradio_share", True),
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from axolotl.utils import (
|
|||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
is_opentelemetry_available,
|
is_opentelemetry_available,
|
||||||
|
is_trackio_available,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GCCallback,
|
GCCallback,
|
||||||
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_trackio and is_trackio_available():
|
||||||
|
from axolotl.utils.callbacks.trackio_ import (
|
||||||
|
SaveAxolotlConfigtoTrackioCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||||
from axolotl.utils.callbacks.opentelemetry import (
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
OpenTelemetryMetricsCallback,
|
OpenTelemetryMetricsCallback,
|
||||||
@@ -434,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
report_to.append("tensorboard")
|
report_to.append("tensorboard")
|
||||||
if self.cfg.use_comet:
|
if self.cfg.use_comet:
|
||||||
report_to.append("comet_ml")
|
report_to.append("comet_ml")
|
||||||
|
if self.cfg.use_trackio:
|
||||||
|
report_to.append("trackio")
|
||||||
|
|
||||||
training_args_kwargs["report_to"] = report_to
|
training_args_kwargs["report_to"] = report_to
|
||||||
|
|
||||||
@@ -441,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
elif self.cfg.use_mlflow:
|
elif self.cfg.use_mlflow:
|
||||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
|
elif self.cfg.use_trackio:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["run_name"] = None
|
training_args_kwargs["run_name"] = None
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ def is_opentelemetry_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_trackio_available():
|
||||||
|
return importlib.util.find_spec("trackio") is not None
|
||||||
|
|
||||||
|
|
||||||
def get_pytorch_version() -> tuple[int, int, int]:
|
def get_pytorch_version() -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Get Pytorch version as a tuple of (major, minor, patch).
|
Get Pytorch version as a tuple of (major, minor, patch).
|
||||||
|
|||||||
44
src/axolotl/utils/callbacks/trackio_.py
Normal file
44
src/axolotl/utils/callbacks/trackio_.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Trackio module for trainer callbacks"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import trackio
|
||||||
|
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.environment import is_package_version_ge
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
|
||||||
|
"""Callback for trackio integration"""
|
||||||
|
|
||||||
|
def __init__(self, axolotl_config_path):
|
||||||
|
self.axolotl_config_path = axolotl_config_path
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: "AxolotlTrainingArguments",
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if is_main_process():
|
||||||
|
try:
|
||||||
|
if not is_package_version_ge("trackio", "0.11.0"):
|
||||||
|
LOG.warning(
|
||||||
|
"Trackio version 0.11.0 or higher is required to save config files. "
|
||||||
|
"Please upgrade trackio: pip install --upgrade trackio"
|
||||||
|
)
|
||||||
|
return control
|
||||||
|
|
||||||
|
trackio.save(self.axolotl_config_path)
|
||||||
|
LOG.info("The Axolotl config has been saved to Trackio.")
|
||||||
|
except (FileNotFoundError, ConnectionError, AttributeError) as err:
|
||||||
|
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
|
||||||
|
return control
|
||||||
@@ -34,6 +34,7 @@ from axolotl.utils.schemas.integrations import (
|
|||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
OpenTelemetryConfig,
|
OpenTelemetryConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
|
TrackioConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
|
||||||
@@ -63,6 +64,7 @@ class AxolotlInputConfig(
|
|||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
CometConfig,
|
||||||
|
TrackioConfig,
|
||||||
OpenTelemetryConfig,
|
OpenTelemetryConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
|
|||||||
@@ -200,3 +200,23 @@ class OpenTelemetryConfig(BaseModel):
|
|||||||
"description": "Port for the Prometheus metrics HTTP server"
|
"description": "Port for the Prometheus metrics HTTP server"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrackioConfig(BaseModel):
|
||||||
|
"""Trackio configuration subset"""
|
||||||
|
|
||||||
|
use_trackio: bool | None = None
|
||||||
|
trackio_project_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Your trackio project name"},
|
||||||
|
)
|
||||||
|
trackio_run_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Set the name of your trackio run"},
|
||||||
|
)
|
||||||
|
trackio_space_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
17
src/axolotl/utils/trackio_.py
Normal file
17
src/axolotl/utils/trackio_.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Module for trackio utilities"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
def setup_trackio_env_vars(cfg: DictDefault):
|
||||||
|
for key in cfg.keys():
|
||||||
|
if key.startswith("trackio_"):
|
||||||
|
value = cfg.get(key, "")
|
||||||
|
|
||||||
|
if value and isinstance(value, str) and len(value) > 0:
|
||||||
|
os.environ[key.upper()] = value
|
||||||
|
|
||||||
|
if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:
|
||||||
|
cfg.use_trackio = True
|
||||||
Reference in New Issue
Block a user