diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml index 62f3167e8..b7082f986 100644 --- a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml @@ -32,6 +32,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml index ccb84e28e..b718ff2eb 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml @@ -28,6 +28,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml index 69a3c434d..af1c93bc0 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml @@ -29,6 +29,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml index 4a0f1ad70..894ba99b8 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml @@ -28,6 +28,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml index b6deacb1b..7c4f97846 100644 --- 
a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml @@ -41,6 +41,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 8 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml index ab026337d..cbb9efc8e 100644 --- a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml @@ -41,6 +41,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 8 micro_batch_size: 1 num_epochs: 1 diff --git a/requirements.txt b/requirements.txt index 093546815..5e1af6940 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,15 +20,16 @@ deepspeed>=0.17.0 trl==0.25.0 hf_xet==1.2.0 kernels>=0.9.0 -trackio +trackio>=0.13.0 +typing_extensions>=4.14.0 optimum==1.16.2 hf_transfer sentencepiece -gradio==5.49.1 +gradio>=6.2.0,<7.0 modal==1.0.2 -pydantic>=2.10.6 +pydantic>=2.10.6,<2.12 addict fire PyYAML>=6.0 diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index b53c6576b..986167f02 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.tee import prepare_debug_log +from axolotl.utils.trackio_ import setup_trackio_env_vars from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -246,6 +247,7 @@ def load_cfg( setup_wandb_env_vars(cfg) setup_mlflow_env_vars(cfg) setup_comet_env_vars(cfg) + setup_trackio_env_vars(cfg) plugin_set_cfg(cfg) TELEMETRY_MANAGER.send_event(event_type="config-processed", 
properties=cfg) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 640be3696..cafa0f4ef 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -288,8 +288,8 @@ def do_inference_gradio( title=cfg.get("gradio_title", "Axolotl Gradio Interface"), ) - demo.queue().launch( - show_api=False, + demo.launch( + footer_links=["gradio", "settings"], share=cfg.get("gradio_share", True), server_name=cfg.get("gradio_server_name", "127.0.0.1"), server_port=cfg.get("gradio_server_port", None), diff --git a/src/axolotl/cli/utils/diffusion.py b/src/axolotl/cli/utils/diffusion.py index 1157bfd66..7bf68048e 100644 --- a/src/axolotl/cli/utils/diffusion.py +++ b/src/axolotl/cli/utils/diffusion.py @@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui( outputs=[masked_preview, html_out], ) - demo.queue().launch( - show_api=False, + demo.launch( + footer_links=["gradio", "settings"], share=cfg.get("gradio_share", True), server_name=cfg.get("gradio_server_name", "127.0.0.1"), server_port=cfg.get("gradio_server_port", None), diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 06d15ffc8..412f6da2f 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -35,6 +35,7 @@ from axolotl.utils import ( is_comet_available, is_mlflow_available, is_opentelemetry_available, + is_trackio_available, ) from axolotl.utils.callbacks import ( GCCallback, @@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC): callbacks.append( SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_trackio and is_trackio_available(): + from axolotl.utils.callbacks.trackio_ import ( + SaveAxolotlConfigtoTrackioCallback, + ) + + callbacks.append( + SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path) + ) if self.cfg.use_otel_metrics and is_opentelemetry_available(): from axolotl.utils.callbacks.opentelemetry import ( OpenTelemetryMetricsCallback, @@ -434,6 +443,8 @@ class 
def is_trackio_available():
    """Report whether the ``trackio`` package can be located for import.

    Only the import machinery is queried; ``trackio`` itself is not imported.
    """
    trackio_spec = importlib.util.find_spec("trackio")
    return trackio_spec is not None
class TrackioConfig(BaseModel):
    """Trackio configuration subset"""

    # Switch for the Trackio integration; left unset (None) by default.
    use_trackio: bool | None = None

    # Project under which runs are grouped.
    trackio_project_name: str | None = Field(
        None,
        json_schema_extra=dict(description="Your trackio project name"),
    )

    # Display name of the individual run.
    trackio_run_name: str | None = Field(
        None,
        json_schema_extra=dict(description="Set the name of your trackio run"),
    )

    # Optional Hugging Face Space that hosts the dashboard.
    trackio_space_id: str | None = Field(
        None,
        json_schema_extra=dict(
            description="Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
        ),
    )
def setup_trackio_env_vars(cfg: DictDefault):
    """Export Trackio settings from the config as environment variables.

    Every option whose key starts with ``trackio_`` and whose value is a
    non-empty string is mirrored into ``os.environ`` under the upper-cased key
    (``trackio_space_id`` becomes ``TRACKIO_SPACE_ID``). When a project name is
    configured, ``cfg.use_trackio`` is switched on and the name is additionally
    published as ``TRACKIO_PROJECT``.

    Args:
        cfg: Axolotl config. ``use_trackio`` is set on it in place when
            ``trackio_project_name`` is present.
    """
    for key in cfg.keys():
        if not key.startswith("trackio_"):
            continue
        value = cfg.get(key, "")
        # Unset (None), empty and non-string options are not exported.
        if isinstance(value, str) and value:
            os.environ[key.upper()] = value

    project_name = cfg.trackio_project_name
    if project_name:
        cfg.use_trackio = True
        # NOTE(review): the Transformers Trackio reporter documents
        # TRACKIO_PROJECT (not TRACKIO_PROJECT_NAME) as its project variable,
        # so the name is exposed under that key as well - confirm against the
        # pinned transformers version. setdefault keeps a value the user
        # exported explicitly.
        if isinstance(project_name, str):
            os.environ.setdefault("TRACKIO_PROJECT", project_name)