From b54f9c942ba32c329c41096cd5cd8d675e395e36 Mon Sep 17 00:00:00 2001 From: Eduard Zl <22470747+eduardzl@users.noreply.github.com> Date: Tue, 11 Nov 2025 04:04:28 +0200 Subject: [PATCH 01/14] _get_tools in ChatTemplateStrategy : function "parameters" can be dict or string (#3238) * When training of function calls, "tools" elements of a dataset can contain same parameter name but with different types. Datasets fails to load such training set. This fix allows "parameters" element of function call to be string( by running "json.dumps" in preparation of training data set). The _get_tools function will iterate over tool definitions, if "parameters" element is dict, it will keep that way, if it is a string, it will be converted to dict by invoking "json.loads" on string value. * feat: add doc on tool parameters json loading * feat: add tests for parameters json string --------- Co-authored-by: ezlotnik Co-authored-by: NanoCode012 --- docs/dataset-formats/conversation.qmd | 7 + .../prompt_strategies/chat_template.py | 17 + ...at_templates_tool_call_string_arguments.py | 295 +++++++++++++++++- 3 files changed, 317 insertions(+), 2 deletions(-) diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 870a2b67d..34fde45fb 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -218,6 +218,13 @@ If you have tool arguments with same name but different dtypes (like `"time": st ``` "arguments": "{\"...\": \"...\"}" ``` + +The same is applicable for tool parameters. + +``` +"parameters": "{\"...\": \"...\"}" +``` + ::: Example config for Llama4: diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index f4dcbd7cd..28155810f 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -823,6 +823,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return None if isinstance(tools, list): + # Process each tool to handle JSON string parameters + for tool in tools: + if isinstance(tool, dict) and "function" in tool: + function = tool["function"] + if "parameters" in function: + params = function["parameters"] + if isinstance(params, str): + try: + function["parameters"] = json.loads(params) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing tool parameters as JSON. " + f"Function: {function.get('name', 'unknown')}, " + f"Parameters string: {params!r}, " + f"Error: {e}" + ) + raise return tools raise ValueError( diff --git a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py index 7de21b940..5866cc367 100644 --- a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py +++ b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py @@ -69,7 +69,7 @@ class TestQwen3IdenticalConversationArgs: { "function": { "name": function_name, - "arguments": arguments_dict, # dict格式 + "arguments": arguments_dict, # dict } } ], @@ -100,7 +100,7 @@ class TestQwen3IdenticalConversationArgs: { "function": { "name": function_name, - "arguments": arguments_str, # str格式 + "arguments": arguments_str, # str } } ], @@ -212,3 +212,294 @@ class TestQwen3IdenticalConversationArgs: decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) assert "2025-08-01" in decoded, "String time value should be present" assert "1690876800" in decoded, "Number time value should be present" + + +class TestQwen3IdenticalToolsParameters: + """ + Test Qwen3 tools parameters handling is identical between JSON string and dict + """ + + @pytest.fixture(name="tools_dict_params_dataset") + def fixture_tools_dict_params_dataset(self): + """ + Provides a dataset with tools where parameters is a dict. + """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + } + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"location": "Boston, MA"}, + }, + } + ], + }, + { + "role": "tool", + "name": "get_weather", + "content": "72°F and sunny", + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="tools_str_params_dataset") + def fixture_tools_str_params_dataset(self): + """ + Provides a dataset with tools where parameters is a JSON string. + """ + parameters_dict = { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + } + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": json.dumps(parameters_dict), + }, + } + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"location": "Boston, MA"}, + }, + } + ], + }, + { + "role": "tool", + "name": "get_weather", + "content": "72°F and sunny", + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="tools_mixed_type_params_dataset") + def fixture_tools_mixed_type_params_dataset(self): + """ + Provides a dataset where different tools have the same parameter name with different types. + This tests that JSON string format prevents casting issues. + """ + tools = [ + { + "type": "function", + "function": { + "name": "tool_with_string_arg", + "description": "Tool expecting string argument", + "parameters": json.dumps( + { + "type": "object", + "properties": { + "arg1": { + "type": "string", + "description": "A string parameter", + } + }, + "required": ["arg1"], + } + ), + }, + }, + { + "type": "function", + "function": { + "name": "tool_with_number_arg", + "description": "Tool expecting number argument", + "parameters": json.dumps( + { + "type": "object", + "properties": { + "arg1": { + "type": "number", + "description": "A numeric parameter", + } + }, + "required": ["arg1"], + } + ), + }, + }, + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "Use both tools"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "tool_with_string_arg", + "arguments": json.dumps({"arg1": "hello"}), + }, + }, + { + "type": "function", + "function": { + "name": "tool_with_number_arg", + "arguments": json.dumps({"arg1": 42}), + }, + }, + ], + }, + ], + } + ] + return Dataset.from_list(data) + + def test_dict_and_str_params_produce_equivalent_output( + self, + tools_dict_params_dataset, + tools_str_params_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that after tokenization and decoding, the outputs for both + dict and string `parameters` in tools are semantically equivalent. + """ + import re + + processed_dict_params = tools_dict_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + processed_str_params = tools_str_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + decoded_dict = qwen3_tokenizer.decode(processed_dict_params[0]["input_ids"]) + decoded_str = qwen3_tokenizer.decode(processed_str_params[0]["input_ids"]) + + # Extract the tool JSON from both outputs + tools_pattern = r"\n(.*?)\n" + + dict_tools_match = re.search(tools_pattern, decoded_dict, re.DOTALL) + str_tools_match = re.search(tools_pattern, decoded_str, re.DOTALL) + + assert dict_tools_match and str_tools_match, ( + "Could not find tools section in output" + ) + + # Parse the JSON and compare as objects (order-independent) + dict_tools_json = json.loads(dict_tools_match.group(1)) + str_tools_json = json.loads(str_tools_match.group(1)) + + # Deep comparison of the tool definitions + assert dict_tools_json == str_tools_json, ( + f"Tool definitions are not equivalent:\n" + f"Dict format: {json.dumps(dict_tools_json, indent=2)}\n" + f"String format: {json.dumps(str_tools_json, indent=2)}" + ) + + # Verify the rest of the structure is the same (excluding the tools JSON part) + # The tools JSON can have different order, so we remove it here. + dict_normalized = re.sub( + r".*?", + "TOOLS_PLACEHOLDER", + decoded_dict, + flags=re.DOTALL, + ) + str_normalized = re.sub( + r".*?", + "TOOLS_PLACEHOLDER", + decoded_str, + flags=re.DOTALL, + ) + + assert dict_normalized == str_normalized, ( + "The overall structure differs between dict and string parameter formats" + ) + + def test_str_params_with_mixed_types_no_error( + self, + tools_mixed_type_params_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that when different tools have the same parameter name with different types, + JSON string format for parameters doesn't cause casting errors. + """ + processed = tools_mixed_type_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + assert len(processed) == 1 + assert "input_ids" in processed[0] + assert len(processed[0]["input_ids"]) > 0 + + decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) + + # Check that both tools are present + assert "tool_with_string_arg" in decoded + assert "tool_with_number_arg" in decoded + + # Check that both argument values are present + assert "hello" in decoded + assert "42" in decoded From dd78f2e0cc5cc6458daaad02cc29b649ff1046f5 Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Mon, 10 Nov 2025 22:32:06 -0500 Subject: [PATCH 02/14] Fix: `warmup_steps: 0` & `warmup_ratio: 0` not disabling warmup (#3254) * fix unintentional falsy checks * chore: lint --------- Co-authored-by: NanoCode012 --- src/axolotl/core/builders/base.py | 4 ++-- tests/e2e/integrations/test_liger.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 2c949f8e7..7954e1fbd 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -196,9 +196,9 @@ class TrainerBuilderBase(abc.ABC): ): warmup_steps = 0 warmup_ratio = 0.0 - if self.cfg.warmup_steps: + if self.cfg.warmup_steps is not None: warmup_steps = self.cfg.warmup_steps - elif self.cfg.warmup_ratio: + elif self.cfg.warmup_ratio is not None: if total_num_steps: warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) else: diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 55317151e..e50483e6c 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -3,6 +3,7 @@ Simple end-to-end test for Liger integration """ import pytest + from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config From 9901ee56028c05a3897b2b0b1bbe893f25611654 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 13 Nov 2025 22:18:42 +0700 Subject: [PATCH 03/14] fix: voxtralprocessor broken (#3255) [skip ci] * fix: voxtralprocessor broken * chore: add todo * chore: wording --- docs/multimodal.qmd | 2 ++ examples/voxtral/voxtral-mini-audio-qlora.yml | 2 +- src/axolotl/loaders/processor.py | 27 +++++++++++++++---- .../utils/mistral/mistral3_processor.py | 1 + 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 1c4e28ea7..e63a553b2 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -124,6 +124,8 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral ```yaml base_model: mistralai/Voxtral-Mini-3B-2507 + +processor_type: VoxtralProcessor ``` ### Gemma-3 {#sec-gemma-3} diff --git a/examples/voxtral/voxtral-mini-audio-qlora.yml b/examples/voxtral/voxtral-mini-audio-qlora.yml index 8fe6adbff..59150c4ca 100644 --- a/examples/voxtral/voxtral-mini-audio-qlora.yml +++ b/examples/voxtral/voxtral-mini-audio-qlora.yml @@ -1,5 +1,5 @@ base_model: mistralai/Voxtral-Mini-3B-2507 -processor_type: AutoProcessor +processor_type: VoxtralProcessor # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index 7580b2008..b35ea00fd 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -1,7 +1,5 @@ """Processor loading functionality for multi-modal models""" -from typing import Any - import transformers from transformers import ( AutoProcessor, @@ -15,13 +13,33 @@ LOG = get_logger(__name__) def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): - processor_kwargs: dict[str, Any] = {} # Do we actually need this? - processor_cls = AutoProcessor if cfg.processor_type: processor_cls = getattr(transformers, cfg.processor_type) if cfg.tokenizer_use_mistral_common: + + def _patch_mistralcommontokenizer(): + """ + Transformers v5 stops reading the sub-processor. + + We need to patch this, so both processors use this. + """ + import transformers.tokenization_mistral_common as tokenization_mistral_common + + from axolotl.utils.mistral import HFMistralTokenizer + + tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer + + _patch_mistralcommontokenizer() + + from transformers import VoxtralProcessor + + if processor_cls == VoxtralProcessor: + return VoxtralProcessor.from_pretrained( + cfg.processor_config, + ) + from axolotl.utils.mistral import Mistral3Processor return Mistral3Processor( @@ -32,7 +50,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): cfg.processor_config, trust_remote_code=cfg.trust_remote_code or False, tokenizer=tokenizer, - **processor_kwargs, ) # Attempt to load image size from processor if available diff --git a/src/axolotl/utils/mistral/mistral3_processor.py b/src/axolotl/utils/mistral/mistral3_processor.py index 85479ca7b..01e8f9f10 100644 --- a/src/axolotl/utils/mistral/mistral3_processor.py +++ b/src/axolotl/utils/mistral/mistral3_processor.py @@ -30,6 +30,7 @@ class Mistral3Processor(ProcessorMixin): Wraps HFMistralTokenizer and adds image processing capabilities. """ + # TODO(nano): This should be removed in transformers V5 attributes = ["tokenizer"] tokenizer_class = "HFMistralTokenizer" From 49b81079891020c20477c8f9b3f63cd683d16d95 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 13 Nov 2025 22:19:16 +0700 Subject: [PATCH 04/14] feat: add granite4 examples (#3256) [skip ci] --- examples/granite4/README.md | 65 +++++++++++++++++++++ examples/granite4/granite-4.0-tiny-fft.yaml | 45 ++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 examples/granite4/README.md create mode 100644 examples/granite4/granite-4.0-tiny-fft.yaml diff --git a/examples/granite4/README.md b/examples/granite4/README.md new file mode 100644 index 000000000..d5efd3349 --- /dev/null +++ b/examples/granite4/README.md @@ -0,0 +1,65 @@ +# Finetune IBM's Granite 4.0 with Axolotl + +[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) are a family of open source models trained by IBM Research. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Granite4 is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.7.1 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/granite4/granite-4.0-tiny-fft.yaml +``` + +This config uses about 40.8GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +### Limitation + +Adapter finetuning does not work at the moment. It would error with + +```bash +RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648) +``` + +In addition, if adapter training works, `lora_target_linear: true` will not work due to: +```bash +ValueError: Target module GraniteMoeHybridParallelExperts() is not supported. +``` + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Granite Docs](https://www.ibm.com/granite/docs/models/granite) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/granite4/granite-4.0-tiny-fft.yaml b/examples/granite4/granite-4.0-tiny-fft.yaml new file mode 100644 index 000000000..7ff8207ae --- /dev/null +++ b/examples/granite4/granite-4.0-tiny-fft.yaml @@ -0,0 +1,45 @@ +base_model: ibm-granite/granite-4.0-tiny-preview + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/model-out + +sequence_len: 2048 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config From dcf24fd24ed59993e03cde0fc17e464a542bf52e Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:51:05 +0530 Subject: [PATCH 05/14] feat: save checkpoint after training started (#3233) * add:config parameters for checkpoint * callback main * test file_type fix * lint * unit * simplify dict/obj handeling * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Delete tests/e2e/integrations/__init__.py * remove hard code path in test * device check * lint * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: NanoCode012 * lint-2 * remove: singal based checkpoints * lint * remove signal tests * add:is_main_process * lint * addis_d:istributed() for tests * remove nested is_main_process * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian * add user_defined_filename --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: NanoCode012 Co-authored-by: Wing Lian --- src/axolotl/core/builders/base.py | 7 + .../utils/callbacks/dynamic_checkpoint.py | 132 ++++++ src/axolotl/utils/schemas/config.py | 8 + .../utils/schemas/dynamic_checkpoint.py | 31 ++ tests/e2e/integrations/__init__.py | 0 .../callbacks/test_dynamic_checkpoint.py | 389 ++++++++++++++++++ 6 files changed, 567 insertions(+) create mode 100644 src/axolotl/utils/callbacks/dynamic_checkpoint.py create mode 100644 src/axolotl/utils/schemas/dynamic_checkpoint.py delete mode 100644 tests/e2e/integrations/__init__.py create mode 100644 tests/utils/callbacks/test_dynamic_checkpoint.py diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 7954e1fbd..fc6759ffb 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -118,6 +118,13 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.gc_steps: callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) + if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled: + from axolotl.utils.callbacks.dynamic_checkpoint import ( + DynamicCheckpointCallback, + ) + + callbacks.append(DynamicCheckpointCallback(self.cfg)) + if self.cfg.use_wandb: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) diff --git a/src/axolotl/utils/callbacks/dynamic_checkpoint.py b/src/axolotl/utils/callbacks/dynamic_checkpoint.py new file mode 100644 index 000000000..632109225 --- /dev/null +++ b/src/axolotl/utils/callbacks/dynamic_checkpoint.py @@ -0,0 +1,132 @@ +from pathlib import Path + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.utils.distributed import ( + barrier, + is_distributed, + is_main_process, +) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save" + + +class DynamicCheckpointCallback(TrainerCallback): + """ + Callback to save checkpoints on-demand during training via: + 1. File-based trigger (works everywhere, rank 0 checks file) + + Thread-safe for multi-GPU distributed training. + + Usage: + # File-based: + touch /path/to/output_dir/axolotl_checkpoint.save + """ + + def _get_config_value(self, config, key, default=None): + """Helper to get config value from dict or object.""" + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + def __init__(self, cfg): + self.cfg = cfg + if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled: + self.enabled = False + return + + self.enabled = True + dc_config = cfg.dynamic_checkpoint + + trigger_file_path = self._get_config_value(dc_config, "trigger_file_path") + self.trigger_filename = ( + trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME + ) + + check_interval = self._get_config_value(dc_config, "check_interval") + self.check_interval = check_interval if check_interval is not None else 100 + self.should_save_checkpoint = False + + LOG.info( + f"Dynamic checkpoint enabled. To trigger checkpoint save:\n" + f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n" + f" • Check interval: every {self.check_interval} steps", + main_process_only=True, + ) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **_kwargs, + ) -> TrainerControl: + """ + Check for checkpoint triggers at the end of each step. + ONLY rank 0 checks the file, then all ranks synchronize. + """ + if not self.enabled: + return control + + trigger_detected = False + + if state.global_step % self.check_interval == 0: + if is_main_process(): + trigger_path = Path(args.output_dir) / self.trigger_filename + + if trigger_path.exists(): + trigger_detected = True + try: + trigger_path.unlink() # Delete the trigger file + LOG.info( + f"Dynamic checkpoint triggered via file '{self.trigger_filename}' " + f"at step {state.global_step}", + main_process_only=True, + ) + except OSError as exc: + LOG.warning( + f"Failed to delete trigger file: {exc}", + main_process_only=True, + ) + + if self.should_save_checkpoint: + trigger_detected = True + self.should_save_checkpoint = False # Reset flag + + if is_distributed(): + import torch + import torch.distributed as dist + + device = getattr( + args, + "device", + torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + + trigger_tensor = torch.tensor( + 1 if trigger_detected else 0, + dtype=torch.long, + device=device, + ) + + dist.broadcast(trigger_tensor, src=0) + + trigger_detected = bool(trigger_tensor.item()) + + barrier() + + if trigger_detected: + control.should_save = True + LOG.info( + f"Saving dynamic checkpoint at step {state.global_step}", + main_process_only=True, + ) + return control diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 86b3aa17b..5ad55f8b7 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -23,6 +23,7 @@ from axolotl.utils.schemas.datasets import ( StepwiseSupervisedDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters +from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( @@ -141,6 +142,13 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "Reward modelling: `True` or `False`"}, ) + dynamic_checkpoint: DynamicCheckpointConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Configuration for dynamic checkpointing (trigger by file or signal). " + "Set 'enabled: true' to activate this feature." + }, + ) process_reward_model: bool | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/dynamic_checkpoint.py b/src/axolotl/utils/schemas/dynamic_checkpoint.py new file mode 100644 index 000000000..e0e1d0c1d --- /dev/null +++ b/src/axolotl/utils/schemas/dynamic_checkpoint.py @@ -0,0 +1,31 @@ +"""Schema for dynamic checkpoint configuration.""" + +from pydantic import BaseModel, Field + + +class DynamicCheckpointConfig(BaseModel): + """Configuration for dynamic checkpoint triggering during training.""" + + enabled: bool = Field( + default=False, + json_schema_extra={ + "description": "Enable dynamic checkpoint triggering during training. " + "Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger. " + }, + ) + check_interval: int = Field( + default=10, + ge=1, + json_schema_extra={ + "description": "Check for trigger file every N steps (reduces I/O overhead). " + "Default: 100" + }, + ) + trigger_file_path: str = Field( + default="", + json_schema_extra={ + "description": "Custom trigger filename (optional). " + "If not specified, defaults to 'axolotl_checkpoint.save'. " + "Specify a filename (not a full path) to override the default." + }, + ) diff --git a/tests/e2e/integrations/__init__.py b/tests/e2e/integrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/utils/callbacks/test_dynamic_checkpoint.py b/tests/utils/callbacks/test_dynamic_checkpoint.py new file mode 100644 index 000000000..1fd792102 --- /dev/null +++ b/tests/utils/callbacks/test_dynamic_checkpoint.py @@ -0,0 +1,389 @@ +"""Unit tests for dynamic checkpoint callback""" + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +from axolotl.utils.callbacks.dynamic_checkpoint import ( + DEFAULT_TRIGGER_FILENAME, + DynamicCheckpointCallback, +) +from axolotl.utils.dict import DictDefault + + +class TestDynamicCheckpointCallbackInit: + """Test callback initialization""" + + def test_callback_disabled_by_default(self): + """Test that callback is disabled when config.enabled=False""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": False}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is False + + def test_callback_disabled_when_none(self): + """Test that callback is disabled when dynamic_checkpoint is None""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": None, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is False + + def test_callback_enabled_when_configured(self): + """Test that callback is enabled when config.enabled=True""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is True + assert callback.check_interval == 10 + + def test_default_trigger_filename(self): + """Test that default trigger filename is used""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.trigger_filename == DEFAULT_TRIGGER_FILENAME + + def test_check_interval_default(self): + """Test default check interval""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.check_interval == 100 # Default from schema + + +class TestDynamicCheckpointFileDetection: + """Test file-based checkpoint triggering""" + + def test_trigger_file_detected_and_deleted(self): + """Test that trigger file is detected and deleted""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + assert trigger_file.exists() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert not trigger_file.exists() + assert result.should_save is True + + def test_check_interval_honored(self): + """Test that file is only checked at check_interval steps""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + args = Mock(output_dir=tmpdir) + control = Mock(should_save=False) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + # Step 5 - shouldn't check (not divisible by 10) + state = Mock(global_step=5) + result = callback.on_step_end(args, state, control) + assert trigger_file.exists() # Still there + assert result.should_save is False + + # Step 10 - should check + state = Mock(global_step=10) + result = callback.on_step_end(args, state, control) + assert not trigger_file.exists() # Deleted + assert result.should_save is True + + def test_no_file_no_trigger(self): + """Test that no trigger occurs when file doesn't exist""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is False + + def test_file_deletion_error_handling(self): + """Test that file deletion errors are handled gracefully""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + with patch.object( + Path, "unlink", side_effect=OSError("Permission denied") + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is True + + +class TestDynamicCheckpointMultiGPU: + """Test multi-GPU synchronization""" + + def test_only_rank_0_checks_file(self): + """Test that only rank 0 checks filesystem in multi-GPU setup""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + # Rank 1 (not main process) - shouldn't check file + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=False, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=True, + ): + with patch("torch.distributed.broadcast") as mock_broadcast: + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.barrier" + ): + mock_tensor = MagicMock() + mock_tensor.item.return_value = 0 + with patch("torch.tensor", return_value=mock_tensor): + callback.on_step_end(args, state, control) + + assert trigger_file.exists() + # Broadcast should have been called + assert mock_broadcast.called + + def test_broadcast_synchronization(self): + """Test that trigger decision is broadcasted to all ranks""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + # Rank 0 detects file + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=True, + ): + with patch("torch.distributed.broadcast") as mock_broadcast: + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.barrier" + ) as mock_barrier: + mock_tensor = MagicMock() + mock_tensor.item.return_value = 1 + with patch("torch.tensor", return_value=mock_tensor): + with patch("torch.cuda.current_device", return_value=0): + result = callback.on_step_end(args, state, control) + + assert mock_broadcast.called + assert mock_barrier.called + # All ranks should trigger + assert result.should_save is True + + +class TestDynamicCheckpointSignalHandling: + """Test signal-based checkpoint triggering""" + + def test_signal_trigger_via_callback(self): + """Test that signal flag triggers checkpoint save""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": { + "enabled": True, + "check_interval": 1, + "enable_signal": True, + }, + "output_dir": tmpdir, + } + ) + + with patch("signal.signal"): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.hasattr", + return_value=True, + ): + callback = DynamicCheckpointCallback(cfg) + + callback.should_save_checkpoint = True + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is True + assert callback.should_save_checkpoint is False + + def test_signal_not_registered_when_disabled(self): + """Test that signal handler is not registered when disabled""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": { + "enabled": True, + "check_interval": 10, + "enable_signal": False, + }, + "output_dir": tmpdir, + } + ) + + with patch("signal.signal") as mock_signal_register: + _ = DynamicCheckpointCallback(cfg) + + assert not mock_signal_register.called + + +class TestDynamicCheckpointDisabled: + """Test behavior when callback is disabled""" + + def test_disabled_callback_does_nothing(self): + """Test that disabled callback doesn't check or trigger""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": False}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + result = callback.on_step_end(args, state, control) + + assert trigger_file.exists() + assert result.should_save is False From 301e22849f41c67c31e065a222235ab120fd4074 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 13 Nov 2025 13:03:01 -0500 Subject: [PATCH 06/14] upgrade to latest deepspeed and make sure latest tagged axolotl images are using torch 2.8.0 (#3261) --- .github/workflows/main.yml | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4040ccdc9..3b182af02 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -25,7 +25,6 @@ jobs: python_version: "3.11" pytorch: 2.7.1 axolotl_extras: vllm - is_latest: true - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" @@ -36,6 +35,7 @@ jobs: python_version: "3.11" pytorch: 2.8.0 axolotl_extras: + is_latest: true runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -99,7 +99,6 @@ jobs: python_version: "3.11" pytorch: 2.7.1 axolotl_extras: vllm - is_latest: true - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" @@ -110,6 +109,7 @@ jobs: python_version: "3.11" pytorch: 2.8.0 axolotl_extras: + is_latest: true runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/setup.py b/setup.py index 9c1161642..a1bdd6bdf 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,7 @@ extras_require = { "ring-flash-attn>=0.1.7", ], "deepspeed": [ - "deepspeed==0.17.5", + "deepspeed==0.18.2", "deepspeed-kernels", ], "mamba-ssm": [ From 0fbde69e9c21133f668fd1bbf1e2d59203c546a7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 14 Nov 2025 10:50:03 -0500 Subject: [PATCH 07/14] only push axolotl images, personal repo is deprecated (#3262) * only push axolotl images, personal repo is deprecated * cleanup --- .github/FUNDING.yml | 6 +++--- .github/workflows/base.yml | 1 - .github/workflows/main.yml | 3 --- .github/workflows/nightlies.yml | 2 -- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 4f6ea8de7..cd443a197 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,13 +1,13 @@ # These are supported funding model platforms -github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username -ko_fi: axolotl_ai # Replace with a single Ko-fi username +ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry -custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480¢erImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 87d6772dd..2e8950dd9 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -90,7 +90,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-base axolotlai/axolotl-base - name: Login to Docker Hub uses: docker/login-action@v2 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3b182af02..4f0cc4c99 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -45,7 +45,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl axolotlai/axolotl tags: | type=ref,event=branch @@ -119,7 +118,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-cloud axolotlai/axolotl-cloud tags: | type=ref,event=branch @@ -179,7 +177,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-cloud-term axolotlai/axolotl-cloud-term tags: | type=ref,event=branch diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 18b036a0d..a24946ae9 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -31,7 +31,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl axolotlai/axolotl tags: | type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }} @@ -84,7 +83,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-cloud axolotlai/axolotl-cloud tags: | type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }} From a6bafb55cbd6973222d6d0d37aee2013ae656ba7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 14 Nov 2025 12:52:14 -0500 Subject: [PATCH 08/14] upgrade datasets to 4.4.1 (#3266) * upgrade datasets * cleanup pip cache earlier * cleanup unused things from worker * also cleanup sdist --- .github/workflows/tests.yml | 24 ++++++++++++++++-------- requirements.txt | 2 +- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7ad9d1ab4..95370ca3d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,6 +59,10 @@ jobs: timeout-minutes: 20 steps: + - name: cleanup node + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL + - name: Check out repository code uses: actions/checkout@v4 @@ -91,6 +95,10 @@ jobs: python scripts/cutcrossentropy_install.py | sh pip3 install -r requirements-dev.txt -r requirements-tests.txt + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + - name: Make sure PyTorch version wasn't clobbered run: | python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" @@ -118,10 +126,6 @@ jobs: flags: unittests,pytorch-${{ matrix.pytorch_version }} fail_ci_if_error: false - - name: cleanup pip cache - run: | - find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; - pytest-sdist: name: PyTest from Source Dist runs-on: ubuntu-latest @@ -134,6 +138,10 @@ jobs: timeout-minutes: 20 steps: + - name: cleanup node + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL + - name: Check out repository code uses: actions/checkout@v4 @@ -167,6 +175,10 @@ jobs: python scripts/cutcrossentropy_install.py | sh pip3 install -r requirements-dev.txt -r requirements-tests.txt + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + - name: Make sure PyTorch version wasn't clobbered run: | python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" @@ -184,10 +196,6 @@ jobs: pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/cli/ - - name: cleanup pip cache - run: | - find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; - gate-skip-e2e: needs: [pre-commit, pytest, pytest-sdist] runs-on: ubuntu-latest diff --git a/requirements.txt b/requirements.txt index a12a3941b..62c1b3cba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ peft>=0.17.1 tokenizers>=0.22.1 transformers==4.57.1 accelerate==1.11.0 -datasets==4.3.0 +datasets==4.4.1 deepspeed>=0.17.0 trl==0.25.0 hf_xet==1.2.0 From 4e558711120c18175ab8689da017e28ec7d0a6cb Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 18 Nov 2025 11:35:25 +0700 Subject: [PATCH 09/14] feat: Add opt-out Telemetry (#3237) * initial telemetry manager impl * adding todo * updates * updates * progress on telemetry: config load, process, model load, train start / end, error tracking * update error file path sanitization function; adding more error tracking * updated sanitization logic, tests * adding runtime metrics (cpu + gpu memory, steps/s, etc.) * tests for runtime metrics telemetry and assoc. callback * small update / fix * simplifying path redaction * sleep on all ranks in distributed setting * adding back in base_model redaction w/ whitelist * fix * doc update * improved redaction, send system info during model config load telemetry, etc. * adding runtime metrics / system info additional accelerator support, etc. * adding runtime metrics / system info additional accelerator support, etc. * remove duplicate info * fixes * fix issue with tests in ci * distributed fix * opt-in version of telemetry * enable / disable logic update * docs fix * doc update * minor fixes * simplifying * slight changes * fix * lint * update posthog dep * coderabbit comments * fix: opt-in model * fix: increase time since last * fix: increase whitelist orgs * fix: posthog init and shutdown * fix: imports * fix: also check grad norm * fix: duplicate plugin_manager calls * fix: bad merge * chore: update docs * fix: cache process per comment * fix: error handling * fix: tests * Revert "fix: error handling" This reverts commit 22d1ea5755500a7c0e08c562270ca85591bf451f. * fix: test telemetry error_handled bool * fix: revert test * chore: final doc fixes --------- Co-authored-by: Dan Saunders Co-authored-by: Dan Saunders --- README.md | 7 + _quarto.yml | 1 + docs/telemetry.qmd | 61 +++ requirements.txt | 3 + src/axolotl/cli/config.py | 8 + src/axolotl/cli/inference.py | 7 +- src/axolotl/cli/merge_lora.py | 2 + src/axolotl/cli/merge_sharded_fsdp_weights.py | 2 + src/axolotl/cli/preprocess.py | 2 + src/axolotl/common/datasets.py | 3 + src/axolotl/core/builders/base.py | 6 + src/axolotl/evaluate.py | 2 + src/axolotl/loaders/adapter.py | 2 + src/axolotl/loaders/model.py | 2 + src/axolotl/loaders/processor.py | 2 + src/axolotl/loaders/tokenizer.py | 2 + src/axolotl/telemetry/__init__.py | 0 src/axolotl/telemetry/callbacks.py | 165 +++++++ src/axolotl/telemetry/errors.py | 160 +++++++ src/axolotl/telemetry/manager.py | 416 ++++++++++++++++++ src/axolotl/telemetry/runtime_metrics.py | 210 +++++++++ src/axolotl/telemetry/whitelist.yaml | 33 ++ src/axolotl/train.py | 27 +- src/axolotl/utils/schemas/config.py | 2 +- tests/conftest.py | 10 +- tests/telemetry/__init__.py | 0 tests/telemetry/conftest.py | 9 + tests/telemetry/test_callbacks.py | 373 ++++++++++++++++ tests/telemetry/test_errors.py | 341 ++++++++++++++ tests/telemetry/test_manager.py | 275 ++++++++++++ tests/telemetry/test_runtime_metrics.py | 357 +++++++++++++++ 31 files changed, 2479 insertions(+), 11 deletions(-) create mode 100644 docs/telemetry.qmd create mode 100644 src/axolotl/telemetry/__init__.py create mode 100644 src/axolotl/telemetry/callbacks.py create mode 100644 src/axolotl/telemetry/errors.py create mode 100644 src/axolotl/telemetry/manager.py create mode 100644 src/axolotl/telemetry/runtime_metrics.py create mode 100644 src/axolotl/telemetry/whitelist.yaml create mode 100644 tests/telemetry/__init__.py create mode 100644 tests/telemetry/conftest.py create mode 100644 tests/telemetry/test_callbacks.py create mode 100644 tests/telemetry/test_errors.py create mode 100644 tests/telemetry/test_manager.py create mode 100644 tests/telemetry/test_runtime_metrics.py diff --git a/README.md b/README.md index 6313a73ca..d6dd67988 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,13 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details. +## 📈 Telemetry + +Axolotl has opt-out telemetry that helps us understand how the project is being used +and prioritize improvements. We collect basic system information, model types, and +error rates—never personal data or file paths. Telemetry is enabled by default. To +disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html). + ## ❤️ Sponsors Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai) diff --git a/_quarto.yml b/_quarto.yml index fad3f6786..c97b9838e 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -241,6 +241,7 @@ website: - docs/installation.qmd - docs/inference.qmd - docs/cli.qmd + - docs/telemetry.qmd - docs/config-reference.qmd - text: "API Reference" href: docs/api diff --git a/docs/telemetry.qmd b/docs/telemetry.qmd new file mode 100644 index 000000000..62d7c9bbc --- /dev/null +++ b/docs/telemetry.qmd @@ -0,0 +1,61 @@ +--- +title: Telemetry +description: A description of the telemetry implementation in Axolotl. +--- + +# Telemetry in Axolotl + +Axolotl implements anonymous telemetry to help maintainers understand how the library +is used and where users encounter issues. This data helps prioritize features, optimize +performance, and fix bugs. + +## Data Collection + +We collect: + +- System info: OS, Python version, Axolotl version, PyTorch version, Transformers +version, etc. +- Hardware info: CPU count, memory, GPU count and models +- Runtime metrics: Training progress, memory usage, timing information +- Usage patterns: Models (from a whitelist) and configurations used +- Error tracking: Stack traces and error messages (sanitized to remove personal +information) + +Personally identifiable information (PII) is not collected. + +## Implementation + +Telemetry is implemented using PostHog and consists of: + +- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the +telemetry system and provides methods for tracking events. +- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and +sends sanitized stack traces. +- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks +runtime metrics during training. +- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends +runtime metrics telemetry. + +The telemetry system will block training startup for 10 seconds to ensure users are +aware of data collection, unless telemetry is explicitly enabled or disabled. + +## Opt-Out Mechanism + +Telemetry is **enabled by default** on an opt-out basis. To disable it, set +`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`. + +A warning message will be logged on start to clearly inform users about telemetry. +We will remove this after some period. + +To hide the warning message about telemetry that is displayed on train, etc. startup, +explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1` +(explicitly disable telemetry). + +## Privacy + +- All path-like config information is automatically redacted from telemetry data +- Model information is only collected for whitelisted organizations + - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations +- Each run generates a unique anonymous ID + - This allows us to link different telemetry events in a single same training run +- Telemetry is only sent from the main process to avoid duplicate events diff --git a/requirements.txt b/requirements.txt index 62c1b3cba..977262df5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,7 @@ schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.7 axolotl-contribs-mit==0.0.5 +# telemetry +posthog==6.7.11 + mistral-common==1.8.5 diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 93ac6147d..3c4ace7b0 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -14,6 +14,8 @@ import yaml from transformers.utils import is_torch_bf16_gpu_available from axolotl.integrations.base import PluginManager +from axolotl.telemetry.errors import send_errors +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, @@ -31,6 +33,8 @@ LOG = get_logger(__name__) API_KEY_FIELDS = {"comet_api_key"} +TELEMETRY_MANAGER = TelemetryManager.get_instance() + def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: """ @@ -164,6 +168,7 @@ def plugin_set_cfg(cfg: DictDefault): plugin_manager.cfg = cfg +@send_errors def load_cfg( config: str | Path | DictDefault = Path("examples/"), **kwargs ) -> DictDefault: @@ -197,6 +202,8 @@ def load_cfg( temp_file.close() cfg.axolotl_config_path = temp_file.name + TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg) + # If there are any options passed in the cli, if it is something that seems valid # from the yaml, then overwrite the value cfg_keys = cfg.keys() @@ -240,6 +247,7 @@ def load_cfg( setup_comet_env_vars(cfg) plugin_set_cfg(cfg) + TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg) cfg_to_log = { k: "[REDACTED]" if k in API_KEY_FIELDS else v for k, v in cfg.items() diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 3e1c01520..640be3696 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -19,7 +19,10 @@ from axolotl.cli.utils.diffusion import ( launch_diffusion_gradio_ui, ) from axolotl.integrations.base import PluginManager -from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.telemetry.errors import send_errors +from axolotl.utils.chat_templates import ( + get_chat_template_from_config, +) from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -43,6 +46,7 @@ def get_multi_line_input() -> str: return instruction +@send_errors def do_inference( *, cfg: DictDefault, @@ -160,6 +164,7 @@ def do_inference( print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) +@send_errors def do_inference_gradio( *, cfg: DictDefault, diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 657ddcfe4..482767b12 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -7,12 +7,14 @@ import fire from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +@send_errors def do_merge_lora(*, cfg: DictDefault) -> None: """ Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 43142d79e..1d9736b9d 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -23,6 +23,7 @@ from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from axolotl.cli.config import load_cfg +from axolotl.telemetry.errors import send_errors from axolotl.utils.logging import get_logger from axolotl.utils.train import determine_last_checkpoint @@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights( return save_path_ +@send_errors def merge_fsdp_weights( checkpoint_dir: str, output_path: str, diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 6c05a55f1..af35dd801 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -17,6 +17,7 @@ from axolotl.cli.config import load_cfg from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.trainer import disable_datasets_caching @@ -24,6 +25,7 @@ from axolotl.utils.trainer import disable_datasets_caching LOG = get_logger(__name__) +@send_errors def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: """ Preprocesses dataset specified in axolotl config. diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 8d7758e66..c95ddb80e 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -9,6 +9,7 @@ from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.loaders import load_processor, load_tokenizer +from axolotl.telemetry.errors import send_errors from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -34,6 +35,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: ) +@send_errors def load_datasets( *, cfg: DictDefault, @@ -96,6 +98,7 @@ def load_datasets( ) +@send_errors def load_preference_datasets( *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None ) -> TrainDatasetMeta: diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index fc6759ffb..0d19b369f 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -29,6 +29,8 @@ from transformers.trainer_pt_utils import AcceleratorConfig from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr +from axolotl.telemetry.callbacks import TelemetryCallback +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils import ( is_comet_available, is_mlflow_available, @@ -162,6 +164,10 @@ class TrainerBuilderBase(abc.ABC): ) ) + telemetry_manager = TelemetryManager.get_instance() + if telemetry_manager.enabled: + callbacks.append(TelemetryCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index e4496bee6..db6fb3f16 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -10,6 +10,7 @@ import torch from datasets import Dataset from transformers.trainer import Trainer +from axolotl.telemetry.errors import send_errors from axolotl.train import ( TrainDatasetMeta, setup_model_and_tokenizer, @@ -63,6 +64,7 @@ def evaluate_dataset( return metrics +@send_errors def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: """ Evaluate a model on training and validation datasets. diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index bcde4bf96..8e8177b62 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -20,6 +20,7 @@ from peft import ( from transformers import PreTrainedModel from axolotl.loaders.utils import get_linear_embedding_layers +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -172,6 +173,7 @@ def load_lora( return model, lora_config +@send_errors def load_adapter( model: PreTrainedModel, cfg: DictDefault, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index aeec46584..1eeed3565 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -49,6 +49,7 @@ from axolotl.loaders.utils import ( load_model_config, ) from axolotl.models.mamba import fix_mamba_attn_for_loss +from axolotl.telemetry.errors import send_errors from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( @@ -158,6 +159,7 @@ class ModelLoader: """Property that determines if FSDP with QLoRA is enabled.""" return self.is_fsdp_enabled and self.cfg.adapter == "qlora" + @send_errors def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: """Load and prepare the model with all configurations and patches. diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index b35ea00fd..827b4be35 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -6,12 +6,14 @@ from transformers import ( PreTrainedTokenizerBase, ) +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +@send_errors def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): processor_cls = AutoProcessor if cfg.processor_type: diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 69455dd77..48856116c 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -13,6 +13,7 @@ from transformers import ( from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN +from axolotl.telemetry.errors import send_errors from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( @@ -119,6 +120,7 @@ def modify_tokenizer_files( return tokenizer_dir +@send_errors def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: """Load and configure the tokenizer based on the provided config.""" diff --git a/src/axolotl/telemetry/__init__.py b/src/axolotl/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/telemetry/callbacks.py b/src/axolotl/telemetry/callbacks.py new file mode 100644 index 000000000..0ce52ffa4 --- /dev/null +++ b/src/axolotl/telemetry/callbacks.py @@ -0,0 +1,165 @@ +"""Trainer callbacks for reporting runtime metrics at regular intervals.""" + +import logging +import time + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.telemetry.manager import TelemetryManager +from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker + +LOG = logging.getLogger(__name__) + +TIME_SINCE_LAST = 60 + + +class TelemetryCallback(TrainerCallback): + """ + Trainer callback for tracking and reporting runtime metrics. + + This callback tracks training progress, runtime, and memory usage, + sending telemetry at configurable intervals. + """ + + report_interval_steps: int = 100 + + def __init__(self): + """Initialize the metrics callback.""" + self.tracker = RuntimeMetricsTracker() + self.telemetry_manager = TelemetryManager.get_instance() + self.current_epoch = -1 + self.start_time = time.time() + self.last_report_time = None + self.last_report_step = 0 + + # pylint: disable=unused-argument + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle training start.""" + self.telemetry_manager.send_event(event_type="train-start") + + # pylint: disable=unused-argument + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle training end.""" + # Send training completion event + self.telemetry_manager.send_event( + event_type="train-end", + properties=self._extract_last_metrics(state) + | self.tracker.metrics.to_dict(), + ) + + # pylint: disable=unused-argument + def on_epoch_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle epoch start.""" + self.current_epoch += 1 + self.tracker.start_epoch(self.current_epoch) + + # pylint: disable=unused-argument + def on_epoch_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle epoch end.""" + self.tracker.end_epoch(self.current_epoch) + + # pylint: disable=unused-argument + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle step end.""" + step = state.global_step + self.tracker.update_step(step) + + # Check if we should report metrics + should_report = ( + step % self.report_interval_steps == 0 + or step == 1 # Always report first step + or step - self.last_report_step >= self.report_interval_steps + ) + + if should_report: + current_time = time.time() + if self.last_report_time is not None: + time_since_last_report = current_time - self.last_report_time + else: + time_since_last_report = current_time - self.start_time + steps_since_last_report = step - self.last_report_step + + # Only report if enough time has passed + if ( + step == 1 + or time_since_last_report >= TIME_SINCE_LAST + or steps_since_last_report >= self.report_interval_steps + ): + # Calculate steps per second for this interval + if time_since_last_report > 0 and steps_since_last_report > 0: + steps_per_second = steps_since_last_report / time_since_last_report + else: + steps_per_second = 0 + + # Update memory metrics + self.tracker.update_memory_metrics() + + # Prepare metrics to report + metrics = self._extract_last_metrics(state) | { + "step": step, + "epoch": self.current_epoch, + "progress": state.epoch, # Fractional epoch progress + "steps_per_second": steps_per_second, + "elapsed_time": current_time - self.start_time, + "time_since_last_report": time_since_last_report, + } + + # Add memory metrics + memory_metrics = self.tracker.get_memory_metrics() + metrics.update({"memory": memory_metrics}) + + # Send telemetry + self.telemetry_manager.send_event( + event_type="train-progress", properties=metrics + ) + + # Update last report time and step + self.last_report_time = current_time + self.last_report_step = step + + def _extract_last_metrics(self, state: TrainerState) -> dict: + """Extract last loss, learning_rate, and grad_norm from log history.""" + if not state.log_history: + return {"loss": 0, "learning_rate": 0, "grad_norm": 0} + + last_log = state.log_history[-1] + return { + "loss": last_log.get("loss", 0), + "learning_rate": last_log.get("learning_rate", 0), + "grad_norm": last_log.get("grad_norm", 0), + } diff --git a/src/axolotl/telemetry/errors.py b/src/axolotl/telemetry/errors.py new file mode 100644 index 000000000..27f2d2192 --- /dev/null +++ b/src/axolotl/telemetry/errors.py @@ -0,0 +1,160 @@ +"""Telemetry utilities for exception and traceback information.""" + +import logging +import os +import re +import traceback +from functools import wraps +from inspect import getmodule +from typing import Any, Callable + +from axolotl.telemetry.manager import TelemetryManager + +LOG = logging.getLogger(__name__) + +ERROR_HANDLED = False + + +def sanitize_stack_trace(stack_trace: str) -> str: + """ + Remove personal information from stack trace messages while keeping Python package codepaths. + + This function identifies Python packages by looking for common patterns in virtual environment + and site-packages directories, preserving the package path while removing user-specific paths. + + Args: + stack_trace: The original stack trace string. + + Returns: + A sanitized version of the stack trace with Python package paths preserved. + """ + # Split the stack trace into lines to process each file path separately + lines = stack_trace.split("\n") + sanitized_lines = [] + + # Regular expression to find file paths in the stack trace + path_pattern = re.compile(r'(?:File ")(.*?)(?:")') + + # Regular expression to identify paths in site-packages or dist-packages + # This matches path segments like "site-packages/package_name" or "dist-packages/package_name" + site_packages_pattern = re.compile( + r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)" + ) + + # Additional common virtual environment patterns + venv_lib_pattern = re.compile( + r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)" + ) + + for line in lines: + # Check if this line contains a file path + path_match = path_pattern.search(line) + + if path_match: + full_path = path_match.group(1) + sanitized_path = "" + + # Try to match site-packages pattern + site_packages_match = site_packages_pattern.search(full_path) + venv_lib_match = venv_lib_pattern.search(full_path) + + if site_packages_match: + # Find the index where the matched pattern starts + idx = full_path.find("site-packages") + if idx == -1: + idx = full_path.find("dist-packages") + + # Keep from 'site-packages' onward + if idx >= 0: + sanitized_path = full_path[idx:] + elif venv_lib_match: + # For other virtual environment patterns, find the package directory + match_idx = venv_lib_match.start(1) + if match_idx > 0: + # Keep from the package name onward + package_name = venv_lib_match.group(1) + idx = full_path.rfind( + package_name, 0, match_idx + len(package_name) + ) + if idx >= 0: + sanitized_path = full_path[idx:] + + # If we couldn't identify a package pattern but path contains 'axolotl' + elif "axolotl" in full_path: + idx = full_path.rfind("axolotl") + if idx >= 0: + sanitized_path = full_path[idx:] + + # Apply the sanitization to the line + if sanitized_path: + line = line.replace(full_path, sanitized_path) + else: + # If we couldn't identify a package pattern, just keep the filename + filename = os.path.basename(full_path) + if filename: + line = line.replace(full_path, filename) + else: + line = line.replace(full_path, "") + + sanitized_lines.append(line) + + return "\n".join(sanitized_lines) + + +def send_errors(func: Callable) -> Callable: + """ + Decorator to send exception info in a function. If an exception is raised, we send + telemetry containing the stack trace and error message. + + If an error occurs in a decorated function that is called by another decorated + function, we'll only send telemetry corresponding to the lower-level function. + + Args: + func: Function to decorate. + + Returns: + Decorated function. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + telemetry_manager = TelemetryManager.get_instance() + + if not telemetry_manager.enabled: + return func(*args, **kwargs) + + try: + return func(*args, **kwargs) + except Exception as exception: + # Only track if we're not already handling an error. This prevents us from + # capturing an error more than once in nested decorated function calls. + global ERROR_HANDLED # pylint: disable=global-statement + if not ERROR_HANDLED: + ERROR_HANDLED = True + + # Get function module path + module = getmodule(func) + module_path = ( + f"{module.__name__}.{func.__name__}" if module else func.__name__ + ) + + # Get stack trace + stack_trace = "".join( + traceback.format_exception( + type(exception), exception, exception.__traceback__ + ) + ) + stack_trace = sanitize_stack_trace(stack_trace) + + # Send error telemetry + telemetry_manager.send_event( + event_type=f"{module_path}-error", + properties={ + "exception": str(exception), + "stack_trace": stack_trace, + }, + ) + + raise + + return wrapper diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py new file mode 100644 index 000000000..82d310cdc --- /dev/null +++ b/src/axolotl/telemetry/manager.py @@ -0,0 +1,416 @@ +"""Telemetry manager and associated utilities.""" + +import atexit +import importlib +import logging +import os +import platform +import time +import uuid +from pathlib import Path +from typing import Any + +import posthog +import psutil +import torch +import yaml + +LOG = logging.getLogger(__name__) + +POSTHOG_HOST = "https://app.posthog.com" +POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" + +OPT_OUT_WARNING_SLEEP_SECONDS = 10 +OPT_OUT_WARNING = ( + "\nTelemetry is now enabled by default to help improve Axolotl. " + "If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n" + "Telemetry data helps us understand:\n" + "- Which features are most used\n" + "- What hardware configurations to prioritize\n" + "- Where users encounter errors\n\n" + "Personally identifiable information (PII) is not collected.\n\n" + "To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) " + "or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n" + "For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n" + f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..." +) + +WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml") + +# NOTE: Need to keep these up to date with any config schema changes +FIELDS_TO_REDACT = { + "base_model", + "tokenizer_config", + "base_model_config", + "pretraining_dataset", # NOTE: this field may be a string or a dictionary + "resume_from_checkpoint", + "hub_model_id", +} +PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"} +PATH_INDICATORS = {"path", "dir"} + +# pylint: disable=duplicate-code +RELEVANT_PACKAGES = { + "torch", + "transformers", + "trl", + "datasets", + "peft", + "bitsandbytes", + "accelerate", + "optimum", + "deepspeed", + "ray", + "axolotl", + "triton", + "mamba-ssm", + "flash-attn", + "xformers", + "autoawq", + "tokenizers", + "sentencepiece", + "torchao", + "lm_eval", +} + + +def is_main_process() -> bool: + """ + Check whether we're running in the main process. + + Note: + We're using this function instead of `torch.utils.distributed.is_main_process` + causes issues with DeepSpeed world_size since. This function avoids that issue + by checking env vars that are set by various launchers. + + Returns: + Whether we're running in the main process. + """ + # If PyTorch distributed is already initialized, use it + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() == 0 + + # Otherwise check environment variables for global rank + # NOTE: need to verify this in SLURM / OpenMPI environments + global_rank = int( + os.environ.get( + "RANK", + os.environ.get( + "GLOBAL_RANK", + os.environ.get( + "SLURM_PROCID", + os.environ.get( + "OMPI_COMM_WORLD_RANK", + "0", + ), + ), + ), + ) + ) + + return global_rank == 0 + + +class TelemetryManager: + """Manages telemetry collection and transmission""" + + _instance = None + _initialized = False + + def __new__(cls): + """ + Telemetry manager constructor. Creates the singleton instance of this class if + it doesn't already exist. + """ + if cls._instance is None: + cls._instance = super(TelemetryManager, cls).__new__(cls) + cls._instance._initialized = False + + return cls._instance + + def __init__(self): + """Telemetry manager initializer""" + if self._initialized: + return + + self.enabled = self._check_telemetry_enabled() + + if self.enabled: + self.run_id = str(uuid.uuid4()) + self.whitelist = self._load_whitelist() + + try: + self.system_info = self._get_system_info() + except Exception as e: # pylint: disable=broad-exception-caught + LOG.warning(f"Error during system info collection: {e}") + self.system_info = None + + self._init_posthog() + + # Register shutdown method to flush posthog telemetry + atexit.register(self.shutdown) + + self._initialized = True + + @classmethod + def get_instance(cls) -> "TelemetryManager": + if cls._instance is None: + cls._instance = TelemetryManager() + + return cls._instance + + def _check_telemetry_enabled(self) -> bool: + """ + Check if telemetry is enabled based on environment variables. We also check + whether this is the main process (for the distributed setting and to avoid + sending duplicate PostHog events per GPU). + + Note: This is enabled by default on an opt-out basis. Set + `AXOLOTL_DO_NOT_TRACK=1` to disable telemetry. For more details, see + https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html. + + Returns: + Boolean denoting whether telemetry is enabled or not. + """ + # Parse relevant env vars + axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK") + do_not_track = os.getenv("DO_NOT_TRACK") + + # Default to enabled (opt-out model) + if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in ( + "0", + "1", + "false", + "true", + ): + # Print opt-out info message for main process only + if is_main_process(): + LOG.warning(OPT_OUT_WARNING) + time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS) + + return True + + # Only rank 0 will send telemetry + if not is_main_process(): + return False + + if do_not_track is None: + do_not_track = "0" + + # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled + enabled = axolotl_do_not_track.lower() not in ( + "1", + "true", + ) and do_not_track.lower() not in ("1", "true") + + return enabled + + def _load_whitelist(self) -> dict: + """Load HuggingFace Hub organization whitelist""" + with open(WHITELIST_PATH, encoding="utf-8") as f: + whitelist = yaml.safe_load(f) + + # Send org strings to lowercase since model names are case insensitive + whitelist["organizations"] = { + org.lower() for org in whitelist["organizations"] + } + + return whitelist + + def _is_whitelisted(self, value: str) -> bool: + """ + Check if model / dataset / etc. org is in whitelist. + + Args: + value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS` + ("base_model", etc.). + + Returns: + Boolean indicating whitelist membership. + """ + # NOTE: This membership-checking logic can be improved. + # What happens when a local model path matches a whitelisted org? + parts = value.split("/") + if len(parts) < 2: + return False + org = parts[0] + whitelisted = org.lower() in self.whitelist["organizations"] + + return whitelisted + + def _init_posthog(self): + """Initialize PostHog client""" + posthog.api_key = POSTHOG_WRITE_KEY + posthog.project_api_key = POSTHOG_WRITE_KEY + posthog.host = POSTHOG_HOST + + def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]: + """ + Redact properties to remove any paths, so as to avoid inadvertently collecting + private or personally identifiable information (PII). We also remove + information related to Wandb, MLflow, etc. configuration. + + Args: + properties: Dictionary of properties to redact. + + Returns: + Properties dictionary with redaction applied. + """ + if not properties: + return {} + + def redact_value(value: Any, key: str = "") -> Any: + """Recursively sanitize values, redacting those with path-like keys""" + if isinstance(key, str) and isinstance(value, str): + # Other redaction special cases + if ( + key in FIELDS_TO_REDACT + or any(prefix in key for prefix in PREFIXES_TO_REDACT) + or any(indicator in key.lower() for indicator in PATH_INDICATORS) + ): + # Fields with whitelisted orgs don't need to be redacted + if not self._is_whitelisted(value): + return "[REDACTED]" + + # Handle nested values + if isinstance(value, dict): + return {k: redact_value(v, k) for k, v in value.items()} + if isinstance(value, list): + return [redact_value(item) for item in value] + + return value + + # Create new dict with redacted values + redacted = {k: redact_value(v, k) for k, v in properties.items()} + + return redacted + + def _get_system_info(self) -> dict[str, Any]: + """Collect system information for various hardware accelerators""" + gpu_info = [] + accelerator_type = "none" + + # NVIDIA GPUs + if torch.cuda.is_available(): + accelerator_type = "cuda" + for i in range(torch.cuda.device_count()): + gpu_info.append( + { + "name": torch.cuda.get_device_name(i), + "memory": torch.cuda.get_device_properties(i).total_memory, + } + ) + + # AMD GPUs + elif hasattr(torch, "hip") and torch.hip.is_available(): + accelerator_type = "hip" + for i in range(torch.hip.device_count()): + gpu_info.append( + { + "name": torch.hip.get_device_name(i), + "memory": ( + torch.hip.get_device_properties(i).total_memory + if hasattr(torch.hip, "get_device_properties") + else None + ), + } + ) + + # Apple Silicon + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + accelerator_type = "mps" + gpu_info.append( + { + "name": "Apple Silicon", + # NOTE: this is memory allocated to this process, not total memory + "memory": torch.mps.driver_allocated_memory(), + } + ) + + # Intel GPUs + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + accelerator_type = "xpu" + for i in range(torch.xpu.device_count()): + memory = None + if hasattr(torch.xpu, "get_device_properties"): + memory = torch.xpu.get_device_properties(i).total_memory + + gpu_info.append( + { + "name": torch.xpu.get_device_name(i), + "memory": memory, + } + ) + + # NPUs + elif hasattr(torch, "npu") and torch.npu.is_available(): + accelerator_type = "npu" + for i in range(torch.npu.device_count()): + memory = None + if hasattr(torch.npu, "get_device_properties"): + memory = torch.npu.get_device_properties(i).total_memory + + gpu_info.append( + { + "name": torch.npu.get_device_name(i), + "memory": memory, + } + ) + + # Get relevant package versions + installed_packages = {} + for package in RELEVANT_PACKAGES: + try: + version = importlib.metadata.version(package) + installed_packages[f"{package}_version"] = version + except importlib.metadata.PackageNotFoundError: + pass + + return { + "os": platform.system(), + "python_version": platform.python_version(), + "cpu_count": psutil.cpu_count(), + "memory_total": psutil.virtual_memory().total, + "accelerator_type": accelerator_type, + "accelerator_count": len(gpu_info), + "accelerator_info": gpu_info, + **installed_packages, + } + + def send_event(self, event_type: str, properties: dict[str, Any] | None = None): + """Send a telemetry event""" + if not self.enabled: + return + + if properties is None: + properties = {} + + # Sanitize properties to remove PII + properties = self._redact_paths(properties) + + # Wrap PostHog errors in try / except to not raise errors during Axolotl usage + try: + # Send event via PostHog + posthog.capture( + distinct_id=self.run_id, + event=event_type, + properties=properties, + disable_geoip=True, + ) + except Exception as e: # pylint: disable=broad-exception-caught + LOG.warning(f"Failed to send telemetry event: {e}") + + # Additionally, send system info telemetry when loading config. + # NOTE: Is this the best place for this? + if event_type == "config-loaded": + self.send_system_info() + + def send_system_info(self): + """Helper method for sending system info""" + if self.system_info is not None: + self.send_event(event_type="system-info", properties=self.system_info) + + def shutdown(self): + """Ensure all queued events are processed before shutdown""" + if self.enabled: + posthog.shutdown() diff --git a/src/axolotl/telemetry/runtime_metrics.py b/src/axolotl/telemetry/runtime_metrics.py new file mode 100644 index 000000000..fa83c00a7 --- /dev/null +++ b/src/axolotl/telemetry/runtime_metrics.py @@ -0,0 +1,210 @@ +"""Telemetry utilities for runtime and memory metrics.""" + +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +import psutil +import torch + +from axolotl.telemetry.manager import TelemetryManager + +LOG = logging.getLogger(__name__) + + +@dataclass +class RuntimeMetrics: + """Container for runtime metrics to be tracked throughout training.""" + + # Timing metrics + start_time: float + epoch_start_times: dict[int, float] = field(init=False) + epoch_end_times: dict[int, float] = field(init=False) + + # Memory metrics + peak_cpu_memory: int = 0 + peak_gpu_memory: dict[int, int] = field(init=False) + + # Progress metrics + total_steps: int = 0 + current_epoch: int = 0 + current_step: int = 0 + + def __post_init__(self): + """Initialize empty metric mappings.""" + self.epoch_start_times = {} + self.epoch_end_times = {} + self.peak_gpu_memory = {} + + @property + def elapsed_time(self) -> float: + """Calculate total elapsed time in seconds.""" + return time.time() - self.start_time + + def epoch_time(self, epoch: int) -> float | None: + """Calculate time taken for a specific epoch in seconds.""" + if epoch in self.epoch_start_times and epoch in self.epoch_end_times: + return self.epoch_end_times[epoch] - self.epoch_start_times[epoch] + + return None + + def average_epoch_time(self) -> float | None: + """Calculate average time per epoch in seconds.""" + completed_epochs = [ + epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times + ] + if not completed_epochs: + return None + + total_time = 0.0 + for epoch in completed_epochs: + epoch_time = self.epoch_time(epoch) + if epoch_time is not None: # Check to avoid mypy warning + total_time += epoch_time + + return total_time / len(completed_epochs) + + def steps_per_second(self) -> float | None: + """Calculate average steps per second across all training.""" + if self.total_steps == 0 or self.elapsed_time == 0: + return None + + return self.total_steps / self.elapsed_time + + def to_dict(self) -> dict[str, Any]: + """Convert metrics to a dictionary for telemetry reporting.""" + metrics = { + "total_time_seconds": self.elapsed_time, + "total_steps": self.total_steps, + "steps_per_second": self.steps_per_second(), + "epochs_completed": len( + [ + epoch + for epoch in self.epoch_start_times + if epoch in self.epoch_end_times + ] + ), + "peak_cpu_memory_bytes": self.peak_cpu_memory, + } + + # Add per-epoch timing if available + epoch_times: dict[str, float] = {} + for epoch in sorted(self.epoch_end_times.keys()): + time_taken = self.epoch_time(epoch) + if time_taken is not None: + epoch_times[f"epoch_{epoch}_seconds"] = time_taken + + if epoch_times: + metrics["epoch_times"] = epoch_times # type: ignore + metrics["average_epoch_time_seconds"] = self.average_epoch_time() + + # Add GPU memory metrics if available + if self.peak_gpu_memory: + gpu_metrics: dict[str, int] = {} + for gpu_id, memory in self.peak_gpu_memory.items(): + gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory + metrics["gpu_memory"] = gpu_metrics # type: ignore + + return metrics + + +class RuntimeMetricsTracker: + """Tracker for runtime metrics during training.""" + + update_interval = 100 + + def __init__(self): + """Initialize the runtime metrics tracker.""" + self.metrics = RuntimeMetrics(start_time=time.time()) + self.telemetry_manager = TelemetryManager.get_instance() + self._process = psutil.Process() + + def start_epoch(self, epoch: int): + """Record the start of a new epoch.""" + self.metrics.current_epoch = epoch + self.metrics.epoch_start_times[epoch] = time.time() + self.update_memory_metrics() + + def end_epoch(self, epoch: int): + """Record the end of an epoch.""" + self.metrics.epoch_end_times[epoch] = time.time() + + def update_step(self, step: int): + """Update the current step count.""" + self.metrics.current_step = step + self.metrics.total_steps += 1 + + # Periodically update memory metrics + if step % self.update_interval == 0: + self.update_memory_metrics() + + def _get_allocated_memory(self) -> dict[int, int]: + """ + Helper function for getting accelerator-agnostic allocated memory. + + Returns: + A dictionary mapping device IDs to allocated memory in bytes + """ + memory_used: dict[int, int] = {} + + # NVIDIA GPUs + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + memory_used[i] = torch.cuda.memory_allocated(i) + + # AMD GPUs + elif hasattr(torch, "hip") and torch.hip.is_available(): + for i in range(torch.hip.device_count()): + if hasattr(torch.hip, "memory_allocated"): + memory_used[i] = torch.hip.memory_allocated(i) + + # Apple Silicon + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + # MPS doesn't have per-device memory stats since there's only one device + if hasattr(torch.mps, "current_allocated_memory"): + memory_used[0] = torch.mps.current_allocated_memory() + + # Intel GPUs + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + for i in range(torch.xpu.device_count()): + if hasattr(torch.xpu, "memory_allocated"): + memory_used[i] = torch.xpu.memory_allocated(i) + + # NPUs + elif hasattr(torch, "npu") and torch.npu.is_available(): + for i in range(torch.npu.device_count()): + if hasattr(torch.npu, "memory_allocated"): + memory_used[i] = torch.npu.memory_allocated(i) + + return memory_used + + def update_memory_metrics(self): + """Update peak memory usage metrics.""" + # CPU memory + cpu_memory = self._process.memory_info().rss + self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory) + + # GPU memory (if available) + memory_used = self._get_allocated_memory() + for i, memory in memory_used.items(): + self.metrics.peak_gpu_memory[i] = max( + self.metrics.peak_gpu_memory.get(i, 0), memory + ) + + def get_memory_metrics(self) -> dict[str, Any]: + """Get the current memory metrics as a dictionary.""" + memory_metrics = { + "cpu_memory_bytes": self._process.memory_info().rss, + "peak_cpu_memory_bytes": self.metrics.peak_cpu_memory, + } + + # GPU memory (if available) + memory_used = self._get_allocated_memory() + for i, memory in memory_used.items(): + memory_metrics[f"gpu_{i}_memory_bytes"] = memory + memory_metrics[f"gpu_{i}_peak_memory_bytes"] = ( + self.metrics.peak_gpu_memory.get(i, 0) + ) + + return memory_metrics diff --git a/src/axolotl/telemetry/whitelist.yaml b/src/axolotl/telemetry/whitelist.yaml new file mode 100644 index 000000000..6c94d6e79 --- /dev/null +++ b/src/axolotl/telemetry/whitelist.yaml @@ -0,0 +1,33 @@ +organizations: + - "axolotl-ai-co" + - "meta-llama" + - "huggingface" + - "nvidia" + - "facebook" + - "google" + - "microsoft" + - "deepseek-ai" + - "HuggingFaceTB" + - "mistralai" + - "Qwen" + - "unsloth" + - "NousResearch" + - "allenai" + - "amd" + - "tiiuae" + - "tencent" + - "zai-org" + - "openai" + - "ibm-granite" + - "arcee-ai" + - "swiss-ai" + - "CohereForAI" + - "deepcogito" + - "THUDM" + - "ai21labs" + - "LiquidAI" + - "canopylabs" + - "state-spaces" + - "mistral-community" + - "llava-hf" + - "ByteDance-Seed" diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 441c50871..cce3b8a6a 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -31,6 +31,8 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module ) from axolotl.integrations.base import PluginManager from axolotl.loaders import ModelLoader, load_processor, load_tokenizer +from axolotl.telemetry.errors import send_errors +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed @@ -45,6 +47,9 @@ if typing.TYPE_CHECKING: LOG = get_logger(__name__) +TELEMETRY_MANAGER = TelemetryManager.get_instance() +PLUGIN_MANAGER = PluginManager.get_instance() + def setup_model_and_tokenizer( cfg: DictDefault, @@ -62,7 +67,10 @@ def setup_model_and_tokenizer( `None`), and processor (if multimodal, else `None`). """ # Load tokenizer - LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + LOG.debug( + f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", + main_process_only=True, + ) tokenizer = load_tokenizer(cfg) # Load processor for multimodal models if needed @@ -78,6 +86,14 @@ def setup_model_and_tokenizer( if model.generation_config is not None: model.generation_config.do_sample = True + TELEMETRY_MANAGER.send_event( + event_type="model-load", properties=model.config.to_dict() + ) + if peft_config: + TELEMETRY_MANAGER.send_event( + event_type="peft-config-load", properties=peft_config.to_dict() + ) + # Apply freezing if specified if cfg.unfrozen_parameters: freeze_layers_except(model, cfg.unfrozen_parameters) @@ -196,8 +212,7 @@ def execute_training( LOG.info("Starting trainer...") trainer.train(resume_from_checkpoint=resume_from_checkpoint) - plugin_manager = PluginManager.get_instance() - plugin_manager.post_train(cfg, trainer.model) + PLUGIN_MANAGER.post_train(cfg, trainer.model) def save_trained_model( @@ -521,9 +536,7 @@ def setup_model_and_trainer( model_ref=model_ref, peft_config=peft_config, ) - - plugin_manager = PluginManager.get_instance() - plugin_manager.post_trainer_create(cfg, trainer) + PLUGIN_MANAGER.post_trainer_create(cfg, trainer) if cfg.use_ray: try: @@ -545,6 +558,7 @@ def setup_model_and_trainer( ) +@send_errors def train( cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]: @@ -595,5 +609,6 @@ def train( create_model_card(cfg, trainer) if not cfg.use_ray: cleanup_distributed() + PLUGIN_MANAGER.post_train(cfg, model) return model, tokenizer, trainer diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 5ad55f8b7..c9b087ea3 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1069,7 +1069,7 @@ class AxolotlInputConfig( class AxolotlConfigWCapabilities(AxolotlInputConfig): - """wrapper to valdiate GPU capabilities with the configured options""" + """Wrapper to valdiate GPU capabilities with the configured options""" capabilities: GPUCapabilities env_capabilities: EnvCapabilities diff --git a/tests/conftest.py b/tests/conftest.py index 98847ebad..d3b9407ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,4 @@ -""" -shared pytest fixtures -""" +"""Shared pytest fixtures""" import functools import importlib @@ -582,3 +580,9 @@ def test_load_fixtures( download_llama2_model_fixture, ): pass + + +@pytest.fixture(autouse=True) +def disable_telemetry(monkeypatch): + monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1") + yield diff --git a/tests/telemetry/__init__.py b/tests/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/telemetry/conftest.py b/tests/telemetry/conftest.py new file mode 100644 index 000000000..47776ce90 --- /dev/null +++ b/tests/telemetry/conftest.py @@ -0,0 +1,9 @@ +"""Shared pytest fixtures for telemetry tests.""" + +import pytest + + +@pytest.fixture(autouse=True) +def del_track_env(monkeypatch): + monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False) + yield diff --git a/tests/telemetry/test_callbacks.py b/tests/telemetry/test_callbacks.py new file mode 100644 index 000000000..97d56a9c6 --- /dev/null +++ b/tests/telemetry/test_callbacks.py @@ -0,0 +1,373 @@ +"""Tests for telemetry callback module.""" + +# pylint: disable=redefined-outer-name + +import time +from unittest.mock import MagicMock, patch + +import pytest +from transformers import TrainerControl, TrainerState, TrainingArguments + +from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback + + +def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0): + """Calculate expected metrics values for tests""" + time_diff = current_time - last_time + step_diff = step - last_step + return { + "steps_per_second": ( + step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0 + ), + "time_since_last_report": time_diff, + "elapsed_time": current_time - start_time, + } + + +@pytest.fixture +def mock_time(): + """Mock time.time() to have predictable values in tests""" + with patch("axolotl.telemetry.callbacks.time") as mock_time: + mock_time.time.return_value = 1000.0 + yield mock_time + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +@pytest.fixture +def mock_runtime_metrics_tracker(): + """Create a mock RuntimeMetricsTracker""" + with patch( + "axolotl.telemetry.callbacks.RuntimeMetricsTracker" + ) as mock_tracker_class: + mock_tracker = MagicMock() + # Set up metrics property on the tracker + mock_metrics = MagicMock() + mock_metrics.to_dict.return_value = { + "total_steps": 100, + "peak_cpu_memory_bytes": 1024, + } + mock_tracker.metrics = mock_metrics + + # Make the constructor return our mock + mock_tracker_class.return_value = mock_tracker + yield mock_tracker + + +@pytest.fixture +def training_args(): + """Create a minimal TrainingArguments instance""" + return TrainingArguments(output_dir="./output") + + +@pytest.fixture +def trainer_state(): + """Create a mock TrainerState""" + state = MagicMock(spec=TrainerState) + state.global_step = 10 + state.epoch = 0.5 # halfway through first epoch + state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}] + return state + + +@pytest.fixture +def trainer_control(): + """Create a mock TrainerControl""" + return MagicMock(spec=TrainerControl) + + +# pylint: disable=unused-argument +@pytest.fixture +def callback(mock_telemetry_manager, mock_runtime_metrics_tracker): + """Create a TelemetryCallback instance with mocked dependencies""" + return TelemetryCallback() + + +class TestTelemetryCallback: + """Tests for the TelemetryCallback class.""" + + def test_initialization(self, callback, mock_runtime_metrics_tracker): + """Test callback initialization.""" + assert callback.current_epoch == -1 + assert callback.tracker == mock_runtime_metrics_tracker + assert callback.last_report_step == 0 + assert hasattr(callback, "start_time") + assert hasattr(callback, "last_report_time") + assert callback.report_interval_steps == 100 + + def test_on_train_begin( + self, + callback, + mock_telemetry_manager, + training_args, + trainer_state, + trainer_control, + ): + """Test on_train_begin sends expected event.""" + callback.on_train_begin(training_args, trainer_state, trainer_control) + + mock_telemetry_manager.send_event.assert_called_once_with( + event_type="train-start" + ) + + def test_on_train_end( + self, + callback, + mock_telemetry_manager, + training_args, + trainer_state, + trainer_control, + ): + """Test on_train_end sends expected event with metrics.""" + callback.on_train_end(training_args, trainer_state, trainer_control) + + mock_telemetry_manager.send_event.assert_called_once() + call_args = mock_telemetry_manager.send_event.call_args[1] + + assert call_args["event_type"] == "train-end" + assert "loss" in call_args["properties"] + assert call_args["properties"]["loss"] == 2.5 + assert "learning_rate" in call_args["properties"] + assert call_args["properties"]["learning_rate"] == 5e-5 + + # Check that metrics from RuntimeMetricsTracker are included + assert "total_steps" in call_args["properties"] + assert call_args["properties"]["total_steps"] == 100 + assert "peak_cpu_memory_bytes" in call_args["properties"] + assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024 + + def test_on_epoch_begin( + self, + callback, + mock_runtime_metrics_tracker, + training_args, + trainer_state, + trainer_control, + ): + """Test on_epoch_begin updates epoch counter and calls tracker.""" + initial_epoch = callback.current_epoch + + callback.on_epoch_begin(training_args, trainer_state, trainer_control) + + assert callback.current_epoch == initial_epoch + 1 + mock_runtime_metrics_tracker.start_epoch.assert_called_once_with( + initial_epoch + 1 + ) + + def test_on_epoch_end( + self, + callback, + mock_runtime_metrics_tracker, + training_args, + trainer_state, + trainer_control, + ): + """Test on_epoch_end calls tracker.""" + # Set current epoch + callback.current_epoch = 2 + + callback.on_epoch_end(training_args, trainer_state, trainer_control) + + mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2) + + def test_on_step_end_no_report( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end updates tracker but doesn't report if criteria not met.""" + # Set up state to avoid reporting + trainer_state.global_step = 42 # Not divisible by report_interval_steps + callback.last_report_step = 41 # Just 1 step since last report + callback.last_report_time = time.time() # Just now + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should update tracker + mock_runtime_metrics_tracker.update_step.assert_called_once_with(42) + + # Should not send telemetry + mock_telemetry_manager.send_event.assert_not_called() + + # Should not update last report time/step + assert callback.last_report_step == 41 + + def test_on_step_end_report_interval_steps( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end reports when step interval is reached.""" + # Set up state with clear values + current_step = 100 # Exactly matches report_interval_steps + last_step = 0 + start_time = 900.0 + current_time = 1000.0 + time_diff = current_time - start_time # 100 seconds + + # Configure state and callback + trainer_state.global_step = current_step + callback.report_interval_steps = 100 + callback.last_report_step = last_step + callback.start_time = start_time + callback.last_report_time = start_time + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should update tracker + mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step) + mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once() + + # Should send telemetry + mock_telemetry_manager.send_event.assert_called_once() + call_args = mock_telemetry_manager.send_event.call_args[1] + assert call_args["event_type"] == "train-progress" + + # Properties should include expected values + props = call_args["properties"] + assert props["step"] == current_step + assert props["elapsed_time"] == time_diff # 1000 - 900 = 100 + assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100 + assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds + + # Should update last report time/step + assert callback.last_report_step == current_step + assert callback.last_report_time == current_time + + def test_on_step_end_report_time_elapsed( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, # pylint: disable=unused-argument + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end reports when enough time has elapsed.""" + # Set up state with clear values + current_step = 120 + last_step = 10 + start_time = 900.0 + current_time = 1000.0 + time_diff = TIME_SINCE_LAST + 1 # Just over the threshold + + # Configure state and callback + trainer_state.global_step = current_step + callback.report_interval_steps = 100 + callback.last_report_step = last_step + callback.start_time = start_time + callback.last_report_time = current_time - time_diff + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should send telemetry + mock_telemetry_manager.send_event.assert_called_once() + + # Properties should include expected values + props = mock_telemetry_manager.send_event.call_args[1]["properties"] + expected_metrics = calc_expected_metrics( + current_step, last_step, current_time, current_time - time_diff, start_time + ) + assert props["steps_per_second"] == expected_metrics["steps_per_second"] + assert ( + props["time_since_last_report"] + == expected_metrics["time_since_last_report"] + ) + + def test_on_step_end_first_step( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, # pylint: disable=unused-argument + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end always reports on first step.""" + # Set up state with clear values + current_step = 1 # First step + last_step = 0 + start_time = 900.0 + current_time = 1000.0 + last_report_time = 999.0 # Just 1 second ago + + # Configure state and callback + trainer_state.global_step = current_step + callback.report_interval_steps = 100 + callback.last_report_step = last_step + callback.start_time = start_time + callback.last_report_time = last_report_time + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should send telemetry even though not much time has passed + mock_telemetry_manager.send_event.assert_called_once() + + # Properties should include expected values for first step + props = mock_telemetry_manager.send_event.call_args[1]["properties"] + assert props["step"] == current_step + expected_metrics = calc_expected_metrics( + current_step, last_step, current_time, last_report_time, start_time + ) + assert props["steps_per_second"] == expected_metrics["steps_per_second"] + + def test_log_history_empty( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, # pylint: disable=unused-argument + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test handling of empty log history.""" + # Set up state with clear values + current_step = 1 + start_time = 900.0 + current_time = 1000.0 + + # Configure state and callback + trainer_state.global_step = current_step + trainer_state.log_history = [] + callback.start_time = start_time + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should still send telemetry + mock_telemetry_manager.send_event.assert_called_once() + + # Properties should have default values for missing log data + props = mock_telemetry_manager.send_event.call_args[1]["properties"] + assert props["loss"] == 0 + assert props["learning_rate"] == 0 diff --git a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py new file mode 100644 index 000000000..2f0510b21 --- /dev/null +++ b/tests/telemetry/test_errors.py @@ -0,0 +1,341 @@ +"""Tests for telemetry error utilities""" + +# pylint: disable=redefined-outer-name + +from unittest.mock import MagicMock, patch + +import pytest + +from axolotl.telemetry.errors import sanitize_stack_trace, send_errors + + +@pytest.fixture(autouse=True) +def reset_error_flag(monkeypatch): + """Reset ERROR_HANDLED flag using monkeypatch""" + import axolotl.telemetry.errors + + monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False) + yield + monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False) + + +@pytest.fixture +def example_stack_trace(): + """Provide a sample stack trace with mixed paths""" + return """Traceback (most recent call last): + File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main + trainer = get_trainer(cfg) + File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer + model = get_model(cfg, tokenizer) + File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model + raise ValueError("Model path not found") +ValueError: Model path not found +""" + + +@pytest.fixture +def windows_stack_trace(): + """Provide a sample stack trace with Windows paths""" + return """Traceback (most recent call last): + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main + trainer = get_trainer(cfg) + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer + model = get_model(cfg, tokenizer) + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained + raise ValueError(f"Unrecognized configuration class {config.__class__}") +ValueError: Unrecognized configuration class +""" + + +@pytest.fixture +def mixed_stack_trace(): + """Provide a sample stack trace with both axolotl and non-axolotl paths""" + return """Traceback (most recent call last): + File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main + trainer = get_trainer(cfg) + File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train + self._inner_training_loop() + File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop + super()._inner_training_loop() + File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__ + data = self._next_data() +RuntimeError: CUDA out of memory +""" + + +@pytest.fixture +def venv_stack_trace(): + """Provide a sample stack trace with virtual environment paths""" + return """Traceback (most recent call last): + File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train + self._inner_training_loop() + File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop + self.accelerator.backward(loss) + File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward + self.scaler.scale(loss).backward(**kwargs) + File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward + torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) +RuntimeError: CUDA out of memory +""" + + +@pytest.fixture +def dist_packages_stack_trace(): + """Provide a sample stack trace with dist-packages paths""" + return """Traceback (most recent call last): + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__ + data = self._next_data() + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data + data = self._dataset_fetcher.fetch(index) + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__ + raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.") +IndexError: Index 10000 out of range for dataset of length 9832. +""" + + +@pytest.fixture +def project_stack_trace(): + """Provide a sample stack trace from a project directory (not a virtual env)""" + return """Traceback (most recent call last): + File "/home/user/projects/myproject/run.py", line 25, in + main() + File "/home/user/projects/myproject/src/cli.py", line 45, in main + app.run() + File "/home/user/projects/myproject/src/app.py", line 102, in run + raise ValueError("Configuration missing") +ValueError: Configuration missing +""" + + +def test_sanitize_stack_trace(example_stack_trace): + """Test that sanitize_stack_trace properly preserves axolotl paths""" + sanitized = sanitize_stack_trace(example_stack_trace) + + # Check that personal paths are removed + assert "/home/user" not in sanitized + assert ".local/lib/python3.9" not in sanitized + + # Check that site-packages is preserved + assert "site-packages/axolotl/cli/train.py" in sanitized + assert "site-packages/axolotl/train.py" in sanitized + assert "site-packages/axolotl/utils/models.py" in sanitized + + # Check that error message is preserved + assert "ValueError: Model path not found" in sanitized + + +def test_sanitize_windows_paths(windows_stack_trace): + """Test that sanitize_stack_trace handles Windows paths""" + sanitized = sanitize_stack_trace(windows_stack_trace) + + # Check that personal paths are removed + assert "C:\\Users\\name" not in sanitized + assert "AppData\\Local\\Programs\\Python" not in sanitized + + # Check that both axolotl and transformers packages are preserved + assert ( + "site-packages\\axolotl\\cli\\train.py" in sanitized + or "site-packages/axolotl/cli/train.py" in sanitized + ) + assert ( + "site-packages\\axolotl\\train.py" in sanitized + or "site-packages/axolotl/train.py" in sanitized + ) + assert ( + "site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized + or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized + ) + + # Check that error message is preserved + assert "ValueError: Unrecognized configuration class" in sanitized + + +def test_sanitize_mixed_paths(mixed_stack_trace): + """Test that sanitize_stack_trace preserves all package paths""" + sanitized = sanitize_stack_trace(mixed_stack_trace) + + # Check that all package paths are preserved + assert "site-packages/axolotl/cli/train.py" in sanitized + assert "site-packages/transformers/trainer.py" in sanitized + assert "site-packages/axolotl/utils/trainer.py" in sanitized + assert "site-packages/torch/utils/data/dataloader.py" in sanitized + + # Check that error message is preserved + assert "RuntimeError: CUDA out of memory" in sanitized + + +def test_sanitize_venv_paths(venv_stack_trace): + """Test that sanitize_stack_trace preserves virtual environment package paths""" + sanitized = sanitize_stack_trace(venv_stack_trace) + + # Check that personal paths are removed + assert "/home/user/venv" not in sanitized + + # Check that all package paths are preserved + assert "site-packages/transformers/trainer.py" in sanitized + assert "site-packages/accelerate/accelerator.py" in sanitized + assert "site-packages/torch/_tensor.py" in sanitized + + # Check that error message is preserved + assert "RuntimeError: CUDA out of memory" in sanitized + + +def test_sanitize_dist_packages(dist_packages_stack_trace): + """Test that sanitize_stack_trace preserves dist-packages paths""" + sanitized = sanitize_stack_trace(dist_packages_stack_trace) + + # Check that system paths are removed + assert "/usr/local/lib/python3.8" not in sanitized + + # Check that all package paths are preserved + assert "dist-packages/torch/utils/data/dataloader.py" in sanitized + assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized + assert "dist-packages/datasets/arrow_dataset.py" in sanitized + + # Check that error message is preserved + assert ( + "IndexError: Index 10000 out of range for dataset of length 9832." in sanitized + ) + + +def test_sanitize_project_paths(project_stack_trace): + """Test handling of project paths (non-virtual env)""" + sanitized = sanitize_stack_trace(project_stack_trace) + + # Check that personal paths are removed + assert "/home/user/projects" not in sanitized + + # For non-package paths, we should at least preserve the filename + assert "run.py" in sanitized + assert "cli.py" in sanitized + assert "app.py" in sanitized + + # Check that error message is preserved + assert "ValueError: Configuration missing" in sanitized + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +def test_send_errors_successful_execution(mock_telemetry_manager): + """Test that send_errors doesn't send telemetry for successful function execution""" + + @send_errors + def test_func(): + return "success" + + result = test_func() + assert result == "success" + mock_telemetry_manager.send_event.assert_not_called() + + +def test_send_errors_with_exception(mock_telemetry_manager): + """Test that send_errors sends telemetry when an exception occurs""" + test_error = ValueError("Test error") + + @send_errors + def test_func(): + raise test_error + + with pytest.raises(ValueError) as excinfo: + test_func() + + assert excinfo.value == test_error + mock_telemetry_manager.send_event.assert_called_once() + + # Check that the error info was passed correctly + call_args = mock_telemetry_manager.send_event.call_args[1] + assert "test_func-error" in call_args["event_type"] + assert "Test error" in call_args["properties"]["exception"] + assert "stack_trace" in call_args["properties"] + + +def test_send_errors_nested_calls(mock_telemetry_manager): + """Test that send_errors only sends telemetry once for nested decorated functions""" + + @send_errors + def inner_func(): + raise ValueError("Inner error") + + @send_errors + def outer_func(): + return inner_func() + + with pytest.raises(ValueError): + outer_func() + + # Telemetry should be sent only once for the inner function + assert mock_telemetry_manager.send_event.call_count == 1 + call_args = mock_telemetry_manager.send_event.call_args[1] + assert "inner_func-error" in call_args["event_type"] + + +def test_send_errors_telemetry_disable(): + """Test that send_errors doesn't attempt to send telemetry when disabled""" + + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = False + mock_manager_class.get_instance.return_value = mock_manager + + @send_errors + def test_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError): + test_func() + + mock_manager.send_event.assert_not_called() + + +def test_error_handled_reset(): + """Test that ERROR_HANDLED flag is properly reset""" + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + # Create and configure the mock manager + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + + from axolotl.telemetry.errors import ERROR_HANDLED + + @send_errors + def test_func(): + raise ValueError("Test error") + + assert not ERROR_HANDLED + + with pytest.raises(ValueError): + test_func() + + from axolotl.telemetry.errors import ERROR_HANDLED + + assert ERROR_HANDLED + + +def test_module_path_resolution(mock_telemetry_manager): + """Test that the module path is correctly resolved for the event type""" + import inspect + + current_module = inspect.getmodule(test_module_path_resolution).__name__ + + @send_errors + def test_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError): + test_func() + + assert mock_telemetry_manager.send_event.called + event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"] + + expected_event_type = f"{current_module}.test_func-error" + assert expected_event_type == event_type diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py new file mode 100644 index 000000000..2eeae2f11 --- /dev/null +++ b/tests/telemetry/test_manager.py @@ -0,0 +1,275 @@ +"""Tests for TelemetryManager class and utilities""" + +# pylint: disable=redefined-outer-name,protected-access + +import os +from unittest.mock import patch + +import pytest +import yaml + +from axolotl.telemetry.manager import TelemetryManager + + +@pytest.fixture +def mock_whitelist(tmp_path): + """Create a temporary whitelist file for testing""" + whitelist_content = { + "organizations": ["meta-llama", "mistralai"], + } + whitelist_file = tmp_path / "whitelist.yaml" + with open(whitelist_file, "w", encoding="utf-8") as f: + yaml.dump(whitelist_content, f) + + return str(whitelist_file) + + +@pytest.fixture +def telemetry_manager_class(): + """Reset the TelemetryManager singleton between tests""" + original_instance = TelemetryManager._instance + original_initialized = TelemetryManager._initialized + TelemetryManager._instance = None + TelemetryManager._initialized = False + yield TelemetryManager + TelemetryManager._instance = original_instance + TelemetryManager._initialized = original_initialized + + +@pytest.fixture +def manager(telemetry_manager_class, mock_whitelist): + """Create a TelemetryManager instance with mocked dependencies""" + with ( + patch("posthog.capture"), + patch("posthog.flush"), + patch("time.sleep"), + patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), + patch.dict(os.environ, {"RANK": "0"}), + ): + manager = telemetry_manager_class() + # Manually enable for most tests + manager.enabled = True + return manager + + +def test_singleton_instance(telemetry_manager_class): + """Test that TelemetryManager is a singleton""" + with ( + patch("posthog.capture"), + patch("time.sleep"), + patch.dict(os.environ, {"RANK": "0"}), + ): + first = telemetry_manager_class() + second = telemetry_manager_class() + assert first is second + assert telemetry_manager_class.get_instance() is first + + +def test_telemetry_enabled_by_default(telemetry_manager_class): + """Test that telemetry is enabled by default (opt-out)""" + with ( + patch.dict(os.environ, {"RANK": "0"}, clear=True), + patch("time.sleep"), + patch("logging.Logger.info"), + ): + manager = telemetry_manager_class() + assert manager.enabled + + +def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class): + """Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0""" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert manager.enabled + + +def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class): + """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1""" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_telemetry_disabled_with_do_not_track(telemetry_manager_class): + """Test that telemetry is disabled when DO_NOT_TRACK=1""" + with ( + patch.dict( + os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"} + ), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): + """Test that telemetry is disabled for non-main processes""" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_opt_in_info_displayed(telemetry_manager_class): + """Test that opt-in info is displayed when telemetry is not configured""" + with ( + patch.dict(os.environ, {"RANK": "0"}, clear=True), + patch("logging.Logger.warning") as mock_warning, + patch("time.sleep"), + ): + telemetry_manager_class() + assert any( + "Telemetry is now enabled by default" in str(call) + for call in mock_warning.call_args_list + ) + + +def test_is_whitelisted(telemetry_manager_class, mock_whitelist): + """Test org whitelist functionality""" + with ( + patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + + # Should match organizations from the mock whitelist + assert manager._is_whitelisted("meta-llama/llama-7b") + assert manager._is_whitelisted("mistralai/mistral-7b-instruct") + # Should not match + assert not manager._is_whitelisted("unknown/model") + # Should handle case insensitively + assert manager._is_whitelisted("META-LLAMA/Llama-7B") + # Should handle empty input + assert not manager._is_whitelisted("") + + +def test_system_info_collection(manager): + """Test system information collection""" + system_info = manager._get_system_info() + + # Check essential keys + assert "os" in system_info + assert "python_version" in system_info + assert "cpu_count" in system_info + assert "memory_total" in system_info + assert "accelerator_count" in system_info + + +def test_send_event(telemetry_manager_class): + """Test basic event sending""" + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + + # Test with clean properties (no PII) + manager.send_event("test_event", {"key": "value"}) + assert mock_capture.called + assert mock_capture.call_args[1]["event"] == "test_event" + assert mock_capture.call_args[1]["properties"] == {"key": "value"} + assert mock_capture.call_args[1]["distinct_id"] == manager.run_id + + # Test with default properties (None) + mock_capture.reset_mock() + manager.send_event("simple_event") + assert mock_capture.called + assert mock_capture.call_args[1]["properties"] == {} + + +def test_send_system_info(telemetry_manager_class): + """Test sending system info""" + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + manager.send_system_info() + assert mock_capture.called + assert mock_capture.call_args[1]["event"] == "system-info" + assert mock_capture.call_args[1]["properties"] == manager.system_info + + +def test_redacted_properties(telemetry_manager_class): + """Test path redaction in send_event method""" + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + # Test with properties containing various paths and non-paths + test_properties = { + "filepath": "/home/user/sensitive/data.txt", + "windows_path": "C:\\Users\\name\\Documents\\project\\file.py", + "output_dir": "/var/lib/data", + "path_to_model": "models/llama/7b", + "message": "Training started", # Should not be redacted + "metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted + "base_model": "models/local_model", + "nested": { + "model_path": "/models/my_model", + "root_dir": "/home/user/projects", + "stats": {"steps": 1000, "epochs": 3}, # Should not be redacted + }, + } + + manager.send_event("test_event", test_properties) + + # Verify the call was made + assert mock_capture.called + + # Get the sanitized properties that were sent + sanitized = mock_capture.call_args[1]["properties"] + + # Check that path-like and base_model keys were redacted + assert sanitized["filepath"] == "[REDACTED]" + assert sanitized["windows_path"] == "[REDACTED]" + assert sanitized["path_to_model"] == "[REDACTED]" + assert sanitized["base_model"] == "[REDACTED]" + + # Check that non-path values were preserved + assert sanitized["message"] == "Training started" + assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95} + + # Check nested structure handling + assert sanitized["nested"]["model_path"] == "[REDACTED]" + assert sanitized["nested"]["root_dir"] == "[REDACTED]" + assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3} + + +def test_disable_telemetry(manager): + """Test that disabled telemetry doesn't send events""" + with patch("posthog.capture") as mock_capture: + manager.enabled = False + manager.send_event("test_event") + assert not mock_capture.called + + +def test_exception_handling_during_send(manager): + """Test that exceptions in PostHog are handled gracefully""" + with ( + patch("posthog.capture", side_effect=Exception("Test error")), + patch("logging.Logger.warning") as mock_warning, + ): + manager.send_event("test_event") + warning_logged = False + for call in mock_warning.call_args_list: + if "Failed to send telemetry event" in str(call): + warning_logged = True + break + assert warning_logged + + +def test_shutdown(manager): + """Test shutdown behavior""" + with patch("posthog.shutdown") as mock_shutdown: + manager.shutdown() + assert mock_shutdown.called diff --git a/tests/telemetry/test_runtime_metrics.py b/tests/telemetry/test_runtime_metrics.py new file mode 100644 index 000000000..c8916e072 --- /dev/null +++ b/tests/telemetry/test_runtime_metrics.py @@ -0,0 +1,357 @@ +"""Tests for runtime metrics telemetry module""" + +# pylint: disable=redefined-outer-name + +from unittest.mock import MagicMock, patch + +import pytest + +from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker + + +@pytest.fixture +def mock_time(): + """Mock time.time() to have predictable values in tests""" + with patch("time.time") as mock_time: + # Start with time 1000.0 and increment by 10 seconds on each call + times = [1000.0 + i * 10 for i in range(10)] + mock_time.side_effect = times + yield mock_time + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch( + "axolotl.telemetry.runtime_metrics.TelemetryManager" + ) as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +@pytest.fixture +def mock_psutil(): + """Mock psutil for memory information""" + with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil: + mock_process = MagicMock() + mock_memory_info = MagicMock() + # Set initial memory to 1GB + mock_memory_info.rss = 1024 * 1024 * 1024 + mock_process.memory_info.return_value = mock_memory_info + mock_psutil.Process.return_value = mock_process + yield mock_psutil + + +@pytest.fixture +def mock_torch(): + """Mock torch.cuda functions""" + with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = True + mock_torch.cuda.device_count.return_value = 2 + + # Mock memory allocated per device (1GB for device 0, 2GB for device 1) + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 1) * 1024 * 1024 * 1024 + ) + + yield mock_torch + + +class TestRuntimeMetrics: + """Tests for RuntimeMetrics class.""" + + def test_initialization(self): + """Test RuntimeMetrics initialization.""" + metrics = RuntimeMetrics(start_time=1000.0) + + assert metrics.start_time == 1000.0 + assert metrics.epoch_start_times == {} + assert metrics.epoch_end_times == {} + assert metrics.peak_gpu_memory == {} + assert metrics.total_steps == 0 + assert metrics.current_epoch == 0 + assert metrics.current_step == 0 + assert metrics.peak_cpu_memory == 0 + + def test_elapsed_time(self, mock_time): + """Test elapsed_time property.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # Mock time.time() to return 1050.0 + mock_time.side_effect = [1050.0] + + assert metrics.elapsed_time == 50.0 + + def test_epoch_time(self): + """Test epoch_time method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No epoch data + assert metrics.epoch_time(0) is None + + # Add epoch start but no end + metrics.epoch_start_times[0] = 1000.0 + assert metrics.epoch_time(0) is None + + # Add epoch end + metrics.epoch_end_times[0] = 1060.0 + assert metrics.epoch_time(0) == 60.0 + + def test_average_epoch_time(self): + """Test average_epoch_time method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No completed epochs + assert metrics.average_epoch_time() is None + + # Add one completed epoch + metrics.epoch_start_times[0] = 1000.0 + metrics.epoch_end_times[0] = 1060.0 + assert metrics.average_epoch_time() == 60.0 + + # Add second completed epoch + metrics.epoch_start_times[1] = 1060.0 + metrics.epoch_end_times[1] = 1140.0 # 80 seconds + assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80 + + # Add incomplete epoch (should not affect average) + metrics.epoch_start_times[2] = 1140.0 + assert metrics.average_epoch_time() == 70.0 + + def test_steps_per_second(self, mock_time): + """Test steps_per_second method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No steps - first call to time.time() + mock_time.side_effect = None + mock_time.return_value = 1050.0 + assert metrics.steps_per_second() is None + + # Add steps - second call to time.time() + metrics.total_steps = 100 + mock_time.return_value = 1050.0 # Keep same time for consistent result + assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds + + def test_to_dict_basic(self, mock_time): + """Test to_dict method with basic metrics.""" + metrics = RuntimeMetrics(start_time=1000.0) + metrics.total_steps = 100 + metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB + + # Mock elapsed_time + mock_time.side_effect = None + mock_time.return_value = 1050.0 + + result = metrics.to_dict() + + assert result["total_time_seconds"] == 50.0 + assert result["total_steps"] == 100 + assert result["steps_per_second"] == 2.0 + assert result["epochs_completed"] == 0 + assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024 + assert "epoch_times" not in result + assert "gpu_memory" not in result + + def test_to_dict_with_epochs(self, mock_time): + """Test to_dict method with epoch data.""" + metrics = RuntimeMetrics(start_time=1000.0) + metrics.total_steps = 100 + + # Add epoch data + metrics.epoch_start_times[0] = 1000.0 + metrics.epoch_end_times[0] = 1060.0 + metrics.epoch_start_times[1] = 1060.0 + metrics.epoch_end_times[1] = 1140.0 + + # Mock elapsed_time + mock_time.side_effect = None + mock_time.return_value = 1150.0 + + result = metrics.to_dict() + + assert "epoch_times" in result + assert result["epoch_times"]["epoch_0_seconds"] == 60.0 + assert result["epoch_times"]["epoch_1_seconds"] == 80.0 + assert result["average_epoch_time_seconds"] == 70.0 + + def test_to_dict_with_gpu_memory(self, mock_time): + """Test to_dict method with GPU memory data.""" + metrics = RuntimeMetrics(start_time=1000.0) + metrics.peak_gpu_memory = { + 0: 1 * 1024 * 1024 * 1024, # 1GB + 1: 2 * 1024 * 1024 * 1024, # 2GB + } + + # Mock elapsed_time + mock_time.side_effect = [1050.0] + + result = metrics.to_dict() + + assert "gpu_memory" in result + assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024 + assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024 + + +class TestRuntimeMetricsTracker: + """Tests for RuntimeMetricsTracker class.""" + + # pylint: disable=unused-argument + def test_initialization(self, mock_time, mock_telemetry_manager): + """Test RuntimeMetricsTracker initialization.""" + tracker = RuntimeMetricsTracker() + + assert isinstance(tracker.metrics, RuntimeMetrics) + assert tracker.metrics.start_time == 1000.0 # First value from mock_time + + # pylint: disable=unused-argument + def test_start_epoch( + self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager + ): + """Test start_epoch method.""" + tracker = RuntimeMetricsTracker() + + # Reset mock_time to control next value + mock_time.side_effect = [1010.0] + + tracker.start_epoch(0) + + assert tracker.metrics.current_epoch == 0 + assert tracker.metrics.epoch_start_times[0] == 1010.0 + + # Verify memory metrics were updated + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + assert 0 in tracker.metrics.peak_gpu_memory + assert 1 in tracker.metrics.peak_gpu_memory + + # pylint: disable=unused-argument + def test_end_epoch(self, mock_time, mock_telemetry_manager): + """Test end_epoch method.""" + tracker = RuntimeMetricsTracker() + + # Start epoch 0 + mock_time.side_effect = [1010.0] + tracker.start_epoch(0) + + # End epoch 0 + mock_time.side_effect = [1060.0] + tracker.end_epoch(0) + + assert 0 in tracker.metrics.epoch_end_times + assert tracker.metrics.epoch_end_times[0] == 1060.0 + + # pylint: disable=unused-argument + def test_update_step( + self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager + ): + """Test update_step method.""" + tracker = RuntimeMetricsTracker() + + # Update step to a non-multiple of 100 + tracker.update_step(42) + + assert tracker.metrics.current_step == 42 + assert tracker.metrics.total_steps == 1 + + # Memory metrics should not be updated for non-multiple of 100 + assert tracker.metrics.peak_cpu_memory == 0 + + # Update step to a multiple of 100 + tracker.update_step(100) + + assert tracker.metrics.current_step == 100 + assert tracker.metrics.total_steps == 2 + + # Memory metrics should be updated for multiple of 100 + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + + # pylint: disable=unused-argument + def test_update_memory_metrics( + self, mock_psutil, mock_torch, mock_telemetry_manager + ): + """Test update_memory_metrics method.""" + tracker = RuntimeMetricsTracker() + + # Initial memory state + assert tracker.metrics.peak_cpu_memory == 0 + assert tracker.metrics.peak_gpu_memory == {} + + # Update memory metrics + tracker.update_memory_metrics() + + # Verify CPU memory + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + + # Verify GPU memory + assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024 + + # Change mocked memory values to be lower + mock_process = mock_psutil.Process.return_value + mock_memory_info = mock_process.memory_info.return_value + mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB + + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 0.5) * 1024 * 1024 * 1024 + ) + + # Update memory metrics again + tracker.update_memory_metrics() + + # Peak values should not decrease + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024 + + # Change mocked memory values to be higher + mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB + + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 2) * 1024 * 1024 * 1024 + ) + + # Update memory metrics again + tracker.update_memory_metrics() + + # Peak values should increase + assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024 + + # pylint: disable=unused-argument + def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager): + """Test get_memory_metrics method.""" + tracker = RuntimeMetricsTracker() + + # Set peak memory values + tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 + tracker.metrics.peak_gpu_memory = { + 0: 3 * 1024 * 1024 * 1024, + 1: 4 * 1024 * 1024 * 1024, + } + + # Get memory metrics + memory_metrics = tracker.get_memory_metrics() + + # Verify CPU memory + assert ( + memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024 + ) # Current value from mock + assert ( + memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024 + ) # Peak value we set + + # Verify GPU memory + assert ( + memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024 + ) # Current value from mock + assert ( + memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024 + ) # Peak value we set + assert ( + memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024 + ) # Current value from mock + assert ( + memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024 + ) # Peak value we set From f5f21fb2161071920dd6a4d5f9c4b74da7d4dc71 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 18 Nov 2025 14:45:21 +0700 Subject: [PATCH 10/14] chore: update readme with latest updates (#3267) --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d6dd67988..c86ab8f4a 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,9 @@ ## 🎉 Latest Updates +- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss). +- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). +- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). - 2025/07: - ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info. - Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm). @@ -36,12 +39,12 @@ - [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl! - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl! - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! -- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
Expand older updates +- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl! - 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version! - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own! From 0d27e14e4538846e23fc0614ae0a0505c36e6eae Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 20 Nov 2025 09:04:37 -0500 Subject: [PATCH 11/14] Torch 2.9.1 base images (#3268) * update torch 2.9.1 base images * update base dockerfile image check --- .github/workflows/base.yml | 8 ++++---- docker/Dockerfile-base | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 2e8950dd9..eddce1438 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -57,14 +57,14 @@ jobs: cuda_version: 12.8.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-base" - cuda: "130" cuda_version: 13.0.0 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "9.0+PTX" dockerfile: "Dockerfile-base" # - cuda: "128" @@ -146,14 +146,14 @@ jobs: cuda_version: 12.8.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-uv-base" - cuda: "130" cuda_version: 13.0.0 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "9.0+PTX" dockerfile: "Dockerfile-uv-base" steps: diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 25eae4fde..cfd30b851 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \ pip3 install -U --no-cache-dir pydantic==1.10.10 && \ pip3 cache purge -RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \ +RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \ wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ From 0b635e69c51758492b97fad668316f6f3127ed4e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 20 Nov 2025 09:26:24 -0500 Subject: [PATCH 12/14] build docker images for 2.9.x (#3273) --- .github/workflows/main.yml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4f0cc4c99..f34a0cf2f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,6 +36,16 @@ jobs: pytorch: 2.8.0 axolotl_extras: is_latest: true + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.0 + axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.1 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -109,6 +119,16 @@ jobs: pytorch: 2.8.0 axolotl_extras: is_latest: true + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.0 + axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.1 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout From 006f226270b83565971706b75032734aa865d345 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 24 Nov 2025 10:21:31 +0700 Subject: [PATCH 13/14] Feat: add Olmo3 (BC with Olmo and Olmo2) (#3275) * feat: update cce to include olmo family * chore: update docs following feedback * feat: add olmo3 config * fix: clarify 3 methods * chore: add olmo to readme --- README.md | 1 + docs/multi-gpu.qmd | 28 +++++--- .../colab-axolotl-example.ipynb | 2 +- examples/olmo3/README.md | 46 +++++++++++++ examples/olmo3/olmo3-7b-qlora.yaml | 64 +++++++++++++++++++ examples/seed-oss/README.md | 26 +++----- examples/smolvlm2/README.md | 4 +- scripts/cutcrossentropy_install.py | 2 +- .../integrations/cut_cross_entropy/README.md | 5 +- .../cut_cross_entropy/__init__.py | 2 +- src/axolotl/monkeypatch/multipack.py | 3 + 11 files changed, 150 insertions(+), 33 deletions(-) create mode 100644 examples/olmo3/README.md create mode 100644 examples/olmo3/olmo3-7b-qlora.yaml diff --git a/README.md b/README.md index c86ab8f4a..1517fb874 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ ## 🎉 Latest Updates +- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3). - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss). - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd index 57a941b04..1b58a108c 100644 --- a/docs/multi-gpu.qmd +++ b/docs/multi-gpu.qmd @@ -4,7 +4,7 @@ format: html: toc: true toc-depth: 3 - number-sections: true + # number-sections: true code-tools: true execute: enabled: false @@ -14,12 +14,18 @@ This guide covers advanced training configurations for multi-GPU setups using Ax ## Overview {#sec-overview} -Axolotl supports several methods for multi-GPU training: +When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy. -- DeepSpeed (recommended) -- FSDP (Fully Sharded Data Parallel) -- Sequence parallelism -- FSDP + QLoRA +You generally cannot combine these strategies; they are mutually exclusive. + +1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3. +2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended). +3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected). + +These features can often be combined with the strategies above: + +* **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP). +* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP). ## DeepSpeed {#sec-deepspeed} @@ -65,12 +71,18 @@ Start from Stage 1 -> Stage 2 -> Stage 3. ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp} +FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers. + ::: {.callout-note} FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl. ::: +### FSDP + QLoRA {#sec-fsdp-qlora} + +For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd). + ### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2} To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and @@ -145,10 +157,6 @@ single sequence causes OOM errors during model training. See our [dedicated guide](sequence_parallelism.qmd) for more information. -### FSDP + QLoRA {#sec-fsdp-qlora} - -For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd). - ## Performance Optimization {#sec-performance} ### Liger Kernel Integration {#sec-liger} diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index cea1aeda0..57a638948 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\"" ] }, { diff --git a/examples/olmo3/README.md b/examples/olmo3/README.md new file mode 100644 index 000000000..d4dbe05a9 --- /dev/null +++ b/examples/olmo3/README.md @@ -0,0 +1,46 @@ +# Finetune Allenai's Olmo 3 with Axolotl + +[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) are a family of 7B and 32B models open source models trained by The Allen Institute for Artificial Intelligence. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + ```bash + # Ensure you have a compatible version of Pytorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' + + # Install Cut Cross Entropy + python scripts/cutcrossentropy_install.py | sh + ``` + +2. Run the finetuning example: + +```bash +axolotl train examples/olmo3/olmo3-7b-qlora.yaml +``` + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- The example config can be re-used for Olmo and Olmo 2. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Related Resources + +- [Olmo 3 Blog](https://allenai.org/blog/olmo3) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/olmo3/olmo3-7b-qlora.yaml b/examples/olmo3/olmo3-7b-qlora.yaml new file mode 100644 index 000000000..c8878d79f --- /dev/null +++ b/examples/olmo3/olmo3-7b-qlora.yaml @@ -0,0 +1,64 @@ +base_model: allenai/Olmo-3-7B-Instruct-SFT + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/seed-oss/README.md b/examples/seed-oss/README.md index 5610c1316..aeb8635e3 100644 --- a/examples/seed-oss/README.md +++ b/examples/seed-oss/README.md @@ -6,21 +6,17 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations ## Getting started -1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). - Here is an example of how to install from main for pip: + Here is an example of how to install from pip: + ```bash + # Ensure you have a compatible version of Pytorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' -```bash -# Ensure you have Pytorch installed (Pytorch 2.6.0 min) -git clone https://github.com/axolotl-ai-cloud/axolotl.git -cd axolotl - -pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja -pip3 install --no-build-isolation -e '.[flash-attn]' - -# Install Cut Cross Entropy -python scripts/cutcrossentropy_install.py | sh -``` + # Install Cut Cross Entropy + python scripts/cutcrossentropy_install.py | sh + ``` 2. Run the finetuning example: @@ -41,9 +37,7 @@ Let us know how it goes. Happy finetuning! 🚀 ## Optimization Guides -- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) -- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) -- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). ## Related Resources diff --git a/examples/smolvlm2/README.md b/examples/smolvlm2/README.md index 9c0ae4836..74c1a1c0f 100644 --- a/examples/smolvlm2/README.md +++ b/examples/smolvlm2/README.md @@ -37,9 +37,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl. ## Optimization Guides -- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) -- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) -- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). ## Related Resources diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index cb498c002..91d0f45d6 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 5c7c5166b..4f98ac089 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953" ``` ## Usage @@ -65,6 +65,9 @@ plugins: - mistral3 - mixtral - mllama +- olmo +- olmo2 +- olmo3 - phi - phi3 - phi4_multimodal diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index bd0124b93..b8f7e9da3 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`' ) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 5d34f1935..fdda3c3bc 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -49,6 +49,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "seed_oss", "lfm2", "lfm2_moe", + "olmo", + "olmo2", + "olmo3", ] From 8990ca32058b61c65dfa60a0b8bfcf0ce624a75f Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 24 Nov 2025 12:18:53 +0530 Subject: [PATCH 14/14] fix: removed unused "scikit-learn==1.4.2" (#3277) Co-authored-by: Ved --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 977262df5..08759279d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,7 +42,6 @@ numpy>=2.2.6 # qlora things evaluate==0.4.1 scipy -scikit-learn==1.4.2 nvidia-ml-py==12.560.30 art tensorboard