From 6e2f5ccf9f03040e5de3252999aa0733fc88261b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 14 Oct 2025 10:21:49 -0400 Subject: [PATCH 01/16] chore: update pre-commit hooks (#3211) [skip ci] Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e853243cd..0e455f52c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.13.3 + rev: v0.14.0 hooks: - id: ruff args: [--fix] From 4cdfdfebb51d6a53d4468c6512b75c500ab85293 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 14 Oct 2025 15:54:05 -0400 Subject: [PATCH 02/16] upgrade transformers==4.57.1 and peft==0.23.1 (#3214) --- requirements.txt | 4 ++-- tests/e2e/multigpu/test_llama.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9c56638a3..e1f1b10a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,11 +13,11 @@ packaging==23.2 huggingface_hub>=0.33.0 peft>=0.17.1 tokenizers>=0.21.1 -transformers==4.57.0 +transformers==4.57.1 accelerate==1.10.1 datasets==4.0.0 deepspeed>=0.17.0 -trl==0.23.0 +trl==0.23.1 hf_xet==1.1.5 kernels==0.9.0 trackio diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index ffdbad942..3383e71d1 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -546,7 +546,6 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high" ) - @pytest.mark.skip("regression failure from v4.57.0") def test_fsdp_qlora_prequant_packed(self, temp_dir): cfg = DictDefault( { From aa1240acd8d7e9640a01a78e9da8a0725b158041 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 16 Oct 2025 16:07:27 +0700 Subject: [PATCH 03/16] fix: transformers deprecate load_in_Xbit in model_kwargs (#3205) * fix: transformers deprecate load_in_Xbit in model_kwargs * fix: test to read from quantization_config kwarg * fix: test * fix: access * fix: test weirdly entering incorrect config --- src/axolotl/loaders/model.py | 16 ++-------------- tests/test_loaders.py | 28 +++++++++++++++++++--------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index f438d6b61..aeec46584 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -515,9 +515,6 @@ class ModelLoader: if self.cfg.model_quantization_config_kwargs: mxfp4_kwargs = self.cfg.model_quantization_config_kwargs self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs) - else: - self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit - self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit if self.cfg.gptq: if not hasattr(self.model_config, "quantization_config"): @@ -552,9 +549,7 @@ class ModelLoader: self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **self.model_config.quantization_config ) - elif self.cfg.adapter == "qlora" and self.model_kwargs.get( - "load_in_4bit", False - ): + elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit: bnb_config = { "load_in_4bit": True, "llm_int8_threshold": 6.0, @@ -580,9 +575,7 @@ class ModelLoader: self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) - elif self.cfg.adapter == "lora" and self.model_kwargs.get( - "load_in_8bit", False - ): + elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit: bnb_config = { "load_in_8bit": True, } @@ -596,11 +589,6 @@ class ModelLoader: **bnb_config, ) - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in self.model_kwargs or self.cfg.gptq: - self.model_kwargs.pop("load_in_8bit", None) - self.model_kwargs.pop("load_in_4bit", None) - def _set_attention_config(self): """Sample packing uses custom FA2 patch""" if self.cfg.attn_implementation: diff --git a/tests/test_loaders.py b/tests/test_loaders.py index f516d0ca4..913090566 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -80,16 +80,26 @@ class TestModelsUtils: hasattr(self.model_loader.model_kwargs, "load_in_8bit") and hasattr(self.model_loader.model_kwargs, "load_in_4bit") ) - elif load_in_8bit and self.cfg.adapter is not None: - assert self.model_loader.model_kwargs["load_in_8bit"] - elif load_in_4bit and self.cfg.adapter is not None: - assert self.model_loader.model_kwargs["load_in_4bit"] - if (self.cfg.adapter == "qlora" and load_in_4bit) or ( - self.cfg.adapter == "lora" and load_in_8bit - ): - assert self.model_loader.model_kwargs.get( - "quantization_config", BitsAndBytesConfig + if self.cfg.adapter == "qlora" and load_in_4bit: + assert isinstance( + self.model_loader.model_kwargs.get("quantization_config"), + BitsAndBytesConfig, + ) + + assert ( + self.model_loader.model_kwargs["quantization_config"]._load_in_4bit + is True + ) + if self.cfg.adapter == "lora" and load_in_8bit: + assert isinstance( + self.model_loader.model_kwargs.get("quantization_config"), + BitsAndBytesConfig, + ) + + assert ( + self.model_loader.model_kwargs["quantization_config"]._load_in_8bit + is True ) def test_message_property_mapping(self): From 93ba57396f103778dc4e02cb954bdb46ef155fe2 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 17 Oct 2025 10:35:03 +0700 Subject: [PATCH 04/16] fix: qwen3_vl attention config (#3216) --- src/axolotl/monkeypatch/lora_kernels.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index e845dc6ce..8e335fe4c 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -134,6 +134,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: return Qwen2Attention + if model_type == "qwen3_vl": + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention + + return Qwen3VLTextAttention + if model_type == "mllama": from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention From 87565ecc05f1b8fd1f8b907dd750d3a5d09adf9a Mon Sep 17 00:00:00 2001 From: Leonard Date: Fri, 17 Oct 2025 19:00:26 +0900 Subject: [PATCH 05/16] Add chat_template.argilla_chat support for DPO datasets (#3202) * Add chat_template.argilla_chat support for DPO datasets Creates a new chat_template.argilla_chat prompt strategy for handling DPO datasets where chosen/rejected fields contain full conversations (messages + final response), following the pattern of chatml.argilla_chat and llama3.argilla_chat. - Add argilla_chat() function to chat_template.py - Add chat_template.argilla_chat to RLHF documentation - Add test coverage for argilla_chat with multiple tokenizers Dataset format: { "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } * Fix chat_template.argilla_chat return value contract and add docstring - Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn - Add remove_columns specification for field_chosen and field_rejected - Add comprehensive docstring with Args/Returns sections - Update tests to unpack tuple return value Addresses PR feedback to maintain consistency with chat_template.default() and properly specify columns to remove after dataset transformation. * Update tests/prompt_strategies/test_dpo_chat_templates.py Co-authored-by: Wing Lian --------- Co-authored-by: Wing Lian --- docs/rlhf.qmd | 15 +++ .../prompt_strategies/dpo/chat_template.py | 120 ++++++++++++++++++ .../test_dpo_chat_templates.py | 78 +++++++++++- 3 files changed, 212 insertions(+), 1 deletion(-) diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 4a67b7559..594ebc743 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format: } ``` +#### chat_template.argilla_chat + +```json +{ + "chosen": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ], + "rejected": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ] +} +``` + #### chat_template.default ```yaml diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index 85c4d2182..58b4d75bd 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -120,3 +120,123 @@ def default(cfg, dataset_idx=0, **kwargs): return result return transform_fn, {"remove_columns": [field_messages]} + + +def argilla_chat(cfg, dataset_idx=0, **kwargs): + """ + DPO chat template strategy for argilla-style datasets. + + For argilla-style datasets where chosen/rejected contain full conversations + instead of single response messages. Extracts the conversation history from + the chosen field and formats both chosen/rejected responses using the + configured chat template. + + Args: + cfg: Configuration object containing chat_template and dataset settings + dataset_idx: Index of the dataset in the config (default: 0) + **kwargs: Additional keyword arguments (unused) + + Returns: + tuple: (transform_fn, dataset_kwargs) where: + - transform_fn: Function to transform dataset samples + - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop + + Dataset format: + { + "chosen": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ], + "rejected": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ] + } + """ + ds_cfg = cfg["datasets"][dataset_idx] + ds_cfg = handle_legacy_message_fields_logic(ds_cfg) + + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) + field_chosen = ds_cfg.get("field_chosen", "chosen") + field_rejected = ds_cfg.get("field_rejected", "rejected") + message_property_mappings = ds_cfg.get( + "message_property_mappings", + { + "role": "role", + "content": "content", + }, + ) + role_map_inv = ds_cfg.get( + "roles", + { + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ) + role_map = {} + for target, sources in role_map_inv.items(): + for source in sources: + role_map[source] = target + + def transform_fn(sample, tokenizer=None): + chat_template_string = get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + + chosen_raw = sample[field_chosen] + rejected_raw = sample[field_rejected] + + # Extract messages (all but last) and responses (last message) + chosen_messages = [ + { + "role": role_map[m[message_property_mappings["role"]]], + "content": m[message_property_mappings["content"]], + } + for m in chosen_raw[:-1] + ] + chosen_response = { + "role": role_map[chosen_raw[-1][message_property_mappings["role"]]], + "content": chosen_raw[-1][message_property_mappings["content"]], + } + + rejected_response = { + "role": role_map[rejected_raw[-1][message_property_mappings["role"]]], + "content": rejected_raw[-1][message_property_mappings["content"]], + } + + dummy_user_message = {"role": "user", "content": "[[dummy_message]]"} + + result = {} + result["prompt"] = tokenizer.apply_chat_template( + chosen_messages, + add_generation_prompt=True, + chat_template=chat_template_string, + tokenize=False, + ) + + result["chosen"] = tokenizer.apply_chat_template( + [dummy_user_message, chosen_response], + add_generation_prompt=False, + chat_template=chat_template_string, + tokenize=False, + ) + chosen_strip_index = result["chosen"].find(chosen_response["content"]) + result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() + + result["rejected"] = tokenizer.apply_chat_template( + [dummy_user_message, rejected_response], + add_generation_prompt=False, + chat_template=chat_template_string, + tokenize=False, + ) + rejected_strip_index = result["rejected"].find(rejected_response["content"]) + result["rejected"] = result["rejected"][rejected_strip_index:].rstrip() + + return result + + return transform_fn, {"remove_columns": [field_chosen, field_rejected]} diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index e570cfc9d..b5c121726 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -8,7 +8,7 @@ import pytest from datasets import Dataset from transformers import AutoTokenizer -from axolotl.prompt_strategies.dpo.chat_template import default +from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default from axolotl.utils.dict import DictDefault from tests.hf_offline_utils import enable_hf_offline @@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset(): ) +@pytest.fixture(name="argilla_chat_dataset") +def fixture_argilla_chat_dataset(): + return Dataset.from_list( + [ + { + "chosen": [ + { + "role": "user", + "content": "hello", + }, + { + "role": "assistant", + "content": "goodbye", + }, + ], + "rejected": [ + { + "role": "user", + "content": "hello", + }, + { + "role": "assistant", + "content": "party on", + }, + ], + } + ] + ) + + @pytest.fixture(name="phi3_tokenizer") @enable_hf_offline def fixture_phi3_tokenizer(): @@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma: assert result["rejected"] == "party on" +class TestArgillaChatDPOChatTemplate: + """ + Test class for argilla_chat style datasets (chosen/rejected contain full conversations). + """ + + def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset): + transform_fn, _ = argilla_chat( + DictDefault( + { + "chat_template": "llama3", + "datasets": [ + { + "type": "chat_template.argilla_chat", + } + ], + } + ) + ) + result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer) + assert result["prompt"] == ( + "<|begin_of_text|>" + + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert result["chosen"] == "goodbye<|eot_id|>" + assert result["rejected"] == "party on<|eot_id|>" + + def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset): + transform_fn, _ = argilla_chat( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template.argilla_chat", + } + ], + } + ) + ) + result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer) + assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n" + assert result["chosen"] == "goodbye<|end|>" + assert result["rejected"] == "party on<|end|>" + + if __name__ == "__main__": unittest.main() From 8bb871b5cf0810fd4034069821250d718db366ca Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 20 Oct 2025 14:06:58 +0700 Subject: [PATCH 06/16] fix: deepspeed with context parallel (#3220) --- .../monkeypatch/transformers/trainer_context_parallel.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py index 74a35e83f..ba8b16dda 100644 --- a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py +++ b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py @@ -13,9 +13,7 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":' -PATCHED_GUARD = ( - 'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):' -) +PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):' def patch_prepare_context_parallel_inputs() -> None: From 383f220cfd658804f4c508a0686c988861ecffbe Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 20 Oct 2025 08:53:49 -0400 Subject: [PATCH 07/16] build torch 2.9.0 base images (#3221) --- .github/workflows/base.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 7af6059c8..b2681bb5d 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -53,6 +53,13 @@ jobs: pytorch: 2.8.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-base" + - cuda: "128" + cuda_version: 12.8.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.9.0 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + dockerfile: "Dockerfile-base" # - cuda: "128" # cuda_version: 12.8.1 # cudnn_version: "" @@ -129,6 +136,13 @@ jobs: pytorch: 2.8.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-uv-base" + - cuda: "128" + cuda_version: 12.8.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.9.0 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + dockerfile: "Dockerfile-uv-base" steps: - name: Checkout uses: actions/checkout@v4 From 613bcf90e58f3ab81d3827e7fc572319908db9fb Mon Sep 17 00:00:00 2001 From: Matthew Hambrecht <14303543+matthambrecht@users.noreply.github.com> Date: Wed, 22 Oct 2025 09:55:26 -0400 Subject: [PATCH 08/16] fix: enable_sleep_mode -> vllm_enable_sleep_mode (#3225) Co-authored-by: Matthew Hambrecht --- src/axolotl/core/trainers/grpo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index d1a6b7fd9..bd77489eb 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -52,7 +52,7 @@ class GRPOStrategy: if trl.vllm_mode: grpo_args_kwargs["vllm_mode"] = trl.vllm_mode if trl.vllm_mode == "colocate": - grpo_args_kwargs["enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined] + grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined] grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( vllm_cfg.gpu_memory_utilization ) From 3750fdcf79313f5c626d9508c72ea167f7da2985 Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Wed, 22 Oct 2025 07:22:14 -0700 Subject: [PATCH 09/16] Fix trainer dataloader slow loading issue (#3219) * Fix trainer dataloader handling in src/axolotl/core/trainers/base.py * update comment to reflect torch version --------- Co-authored-by: Wing Lian --- setup.py | 2 +- src/axolotl/core/trainers/base.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index b2eeb92d6..a93d8d49e 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def parse_requirements(extras_require_map): try: torch_version = version("torch") except PackageNotFoundError: - torch_version = "2.6.0" # default to torch 2.6 + torch_version = "2.8.0" # default to torch 2.8.0 _install_requires.append(f"torch=={torch_version}") version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 11dfecb98..7d7420fb8 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -225,17 +225,6 @@ class AxolotlTrainer( data_collator = self.data_collator if is_training else self.eval_data_collator - if dataset.column_names and "length" in dataset.column_names: - dataset = dataset.remove_columns(["length"]) - if ( - dataset.column_names - and "position_ids" in dataset.column_names - and "attention_mask" in dataset.column_names - and self.args.sample_packing - and self.args.sample_packing_drop_attention_mask - ): - dataset = dataset.remove_columns(["attention_mask"]) - if isinstance(dataset, datasets.Dataset): if is_training: if not self.args.sample_packing or self.args.pretraining: @@ -294,6 +283,18 @@ class AxolotlTrainer( ): self.accelerator.even_batches = False + if dataset.column_names and "length" in dataset.column_names: + dataset = dataset.remove_columns(["length"]) + + if ( + dataset.column_names + and "position_ids" in dataset.column_names + and "attention_mask" in dataset.column_names + and self.args.sample_packing + and self.args.sample_packing_drop_attention_mask + ): + dataset = dataset.remove_columns(["attention_mask"]) + dataloader = DataLoader(dataset, **dataloader_params) # Accelerator.free_memory() will destroy the references, so From 243620394a2576db507b1f6ab033c4183a18233e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 23 Oct 2025 05:23:20 +0700 Subject: [PATCH 10/16] fix: force train split for json,csv,txt for test_datasets and misc doc changes (#3226) * fix: force train split for json,csv,txt for test_datasets * feat(doc): add info on mixing datasets for VLM * feat(doc): max memory * fix(doc): clarify lr groups * fix: add info on vision not being dropped * feat: add qwen3-vl to multimodal docs * fix: add moe blocks to arch list * feat(doc): improve mistral docs * chore: add helpful link [skip-e2e] * fix: add vram usage for mistral small * Update link in docs/faq.qmd Co-authored-by: salman --------- Co-authored-by: Wing Lian Co-authored-by: salman --- docs/faq.qmd | 8 +++ docs/lr_groups.qmd | 6 +++ docs/multimodal.qmd | 14 ++++- examples/magistral/think/README.md | 2 +- examples/magistral/vision/README.md | 2 +- examples/mistral/mistral-small/README.md | 51 +++++++++++++++++++ .../mistral-small-3.1-24B-lora.yml | 2 +- src/axolotl/common/architectures.py | 2 + src/axolotl/utils/data/shared.py | 5 ++ 9 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 examples/mistral/mistral-small/README.md diff --git a/docs/faq.qmd b/docs/faq.qmd index ffc29d35d..92b432f2d 100644 --- a/docs/faq.qmd +++ b/docs/faq.qmd @@ -63,6 +63,14 @@ description: Frequently asked questions > A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717. +**Q: Can we mix text and text+image datasets for VLM training?** + +> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know! + +**Q: Why is `memory/max_*` different from `nvidia-smi`?** + +> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information. + ### Chat templates **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`** diff --git a/docs/lr_groups.qmd b/docs/lr_groups.qmd index 52059016c..ce5350722 100644 --- a/docs/lr_groups.qmd +++ b/docs/lr_groups.qmd @@ -27,3 +27,9 @@ learning_rate: 2e-5 In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's self attention `q_proj` module. + +::: {.callout-note} + +We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17 + +::: diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 3a28b579a..1c4e28ea7 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -56,10 +56,14 @@ image_resize_algorithm: bilinear Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs. -::: {.callout-warning} +::: {.callout-tip} Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs. ::: +::: {.callout-note} +As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this. +::: + ### Mllama {#sec-mllama} ```yaml @@ -168,6 +172,14 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct chat_template: qwen2_vl # same as qwen2-vl ``` +### Qwen3-VL {#sec-qwen3-vl} + +```yaml +base_model: Qwen/Qwen3-VL-4B-Instruct + +chat_template: qwen2_vl # same as qwen2-vl +``` + ### SmolVLM2 {#sec-smolvlm2} ::: {.callout-tip} diff --git a/examples/magistral/think/README.md b/examples/magistral/think/README.md index 29950f59e..a87579775 100644 --- a/examples/magistral/think/README.md +++ b/examples/magistral/think/README.md @@ -12,7 +12,7 @@ Before starting, ensure you have: Run the thinking model fine-tuning: ```bash -axolotl train magistral-small-think-qlora.yaml +axolotl train examples/magistral/think/magistral-small-think-qlora.yaml ``` This config uses about 19.1 GiB VRAM. diff --git a/examples/magistral/vision/README.md b/examples/magistral/vision/README.md index 932a3631e..fc614c850 100644 --- a/examples/magistral/vision/README.md +++ b/examples/magistral/vision/README.md @@ -21,7 +21,7 @@ Before starting, ensure you have: 3. Run the fine-tuning: ```bash - axolotl train magistral-small-vision-24B-qlora.yml + axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml ``` This config uses about 17GiB VRAM. diff --git a/examples/mistral/mistral-small/README.md b/examples/mistral/mistral-small/README.md new file mode 100644 index 000000000..3c606a897 --- /dev/null +++ b/examples/mistral/mistral-small/README.md @@ -0,0 +1,51 @@ +# Mistral Small 3.1/3.2 Fine-tuning + +This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html)) + +## Getting Started + +1. Install the required vision lib: + ```bash + pip install 'mistral-common[opencv]==1.8.5' + ``` + +2. Download the example dataset image: + ```bash + wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg + ``` + +3. Run the fine-tuning: + ```bash + axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml + ``` + +This config uses about 29.4 GiB VRAM. + +## Dataset Format + +The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). + +One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now. + +Example: +```json +{ + "messages": [ + {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]}, + {"role": "user", "content": [ + { "type": "text", "text": "What's in this image?"}, + {"type": "image", "path": "path/to/image.jpg" } + ]}, + {"role": "assistant", "content": [{ "type": "text", "text": "..." }]}, + ], +} +``` + +## Limitations + +- Sample Packing is not supported for multi-modality training currently. diff --git a/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml index ec197f333..d45d13ac6 100644 --- a/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml +++ b/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml @@ -39,7 +39,7 @@ wandb_name: wandb_log_model: gradient_accumulation_steps: 1 -micro_batch_size: 1 +micro_batch_size: 2 num_epochs: 1 optimizer: adamw_bnb_8bit lr_scheduler: cosine diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index b754e56ba..c8a2f0836 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -12,7 +12,9 @@ MOE_ARCH_BLOCK = { "mixtral": "MixtralSparseMoeBlock", "qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen3_moe": "Qwen3MoeSparseMoeBlock", + "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock", "deepseek_v2": "DeepseekV2MoE", + "deepseek_v3": "DeepseekV3MoE", "gpt_oss": "GptOssDecoderLayer", "lfm2_moe": "Lfm2MoeSparseMoeBlock", } diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index c9a91b829..a8ed55ae2 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -239,6 +239,11 @@ def _load_from_local_path( return load_dataset(dataset_config.path, **load_dataset_kwargs) elif local_path.is_file(): dataset_type = get_dataset_type(dataset_config) + + # For single file datasets, HF always creates only a "train" split + if dataset_type in ("json", "csv", "text"): + load_dataset_kwargs["split"] = "train" + return load_dataset( dataset_type, data_files=dataset_config.path, From 4dc018992dccba6fa5e239d0453cbbd565e47e96 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 23 Oct 2025 07:46:55 +0530 Subject: [PATCH 11/16] Feat/opentelemetry (#3215) --- examples/llama-3/opentelemetry-qlora.yml | 50 +++ setup.py | 6 + src/axolotl/core/builders/base.py | 12 +- src/axolotl/utils/__init__.py | 7 + src/axolotl/utils/callbacks/opentelemetry.py | 238 +++++++++++++ src/axolotl/utils/schemas/config.py | 2 + src/axolotl/utils/schemas/integrations.py | 24 ++ tests/test_opentelemetry_callback.py | 349 +++++++++++++++++++ 8 files changed, 687 insertions(+), 1 deletion(-) create mode 100644 examples/llama-3/opentelemetry-qlora.yml create mode 100644 src/axolotl/utils/callbacks/opentelemetry.py create mode 100644 tests/test_opentelemetry_callback.py diff --git a/examples/llama-3/opentelemetry-qlora.yml b/examples/llama-3/opentelemetry-qlora.yml new file mode 100644 index 000000000..d8ce7b1ec --- /dev/null +++ b/examples/llama-3/opentelemetry-qlora.yml @@ -0,0 +1,50 @@ +base_model: NousResearch/Llama-3.2-1B +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_4bit: true + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca + +output_dir: ./outputs/opentelemetry-example + +adapter: qlora +sequence_len: 512 +sample_packing: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +# OpenTelemetry Configuration +use_otel_metrics: true +otel_metrics_host: "localhost" +otel_metrics_port: 8000 + +# Disable WandB +use_wandb: false + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: false + +warmup_ratio: 0.1 +evals_per_epoch: 2 +saves_per_epoch: 1 +weight_decay: 0.0 + +special_tokens: + pad_token: "<|end_of_text|>" diff --git a/setup.py b/setup.py index a93d8d49e..9e3de48b5 100644 --- a/setup.py +++ b/setup.py @@ -159,6 +159,12 @@ extras_require = { "llmcompressor==0.5.1", ], "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"], + "opentelemetry": [ + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-prometheus", + "prometheus-client", + ], } install_requires, dependency_links, extras_require_build = parse_requirements( extras_require diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 8c86e335e..2c949f8e7 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr -from axolotl.utils import is_comet_available, is_mlflow_available +from axolotl.utils import ( + is_comet_available, + is_mlflow_available, + is_opentelemetry_available, +) from axolotl.utils.callbacks import ( GCCallback, SaveAxolotlConfigtoWandBCallback, @@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC): callbacks.append( SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_otel_metrics and is_opentelemetry_available(): + from axolotl.utils.callbacks.opentelemetry import ( + OpenTelemetryMetricsCallback, + ) + + callbacks.append(OpenTelemetryMetricsCallback(self.cfg)) if self.cfg.save_first_step: callbacks.append(SaveModelOnFirstStepCallback()) diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 7256a5700..72f8173f3 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -17,6 +17,13 @@ def is_comet_available(): return importlib.util.find_spec("comet_ml") is not None +def is_opentelemetry_available(): + return ( + importlib.util.find_spec("opentelemetry") is not None + and importlib.util.find_spec("prometheus_client") is not None + ) + + def get_pytorch_version() -> tuple[int, int, int]: """ Get Pytorch version as a tuple of (major, minor, patch). diff --git a/src/axolotl/utils/callbacks/opentelemetry.py b/src/axolotl/utils/callbacks/opentelemetry.py new file mode 100644 index 000000000..3f7e56b78 --- /dev/null +++ b/src/axolotl/utils/callbacks/opentelemetry.py @@ -0,0 +1,238 @@ +"""OpenTelemetry metrics callback for Axolotl training""" + +import threading +from typing import Dict, Optional + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +try: + from opentelemetry import metrics + from opentelemetry.exporter.prometheus import PrometheusMetricReader + from opentelemetry.metrics import set_meter_provider + from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider + from prometheus_client import start_http_server + + OPENTELEMETRY_AVAILABLE = True +except ImportError: + LOG.warning("OpenTelemetry not available. pip install [opentelemetry]") + OPENTELEMETRY_AVAILABLE = False + + +class OpenTelemetryMetricsCallback(TrainerCallback): + """ + TrainerCallback that exports training metrics to OpenTelemetry/Prometheus. + + This callback automatically tracks key training metrics including: + - Training loss + - Evaluation loss + - Learning rate + - Epoch progress + - Global step count + - Gradient norm + + Metrics are exposed via HTTP endpoint for Prometheus scraping. + """ + + def __init__(self, cfg): + if not OPENTELEMETRY_AVAILABLE: + LOG.warning("OpenTelemetry not available, metrics will not be collected") + self.metrics_enabled = False + return + + self.cfg = cfg + self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost") + self.metrics_port = getattr(cfg, "otel_metrics_port", 8000) + self.metrics_enabled = True + self.server_started = False + self.metrics_lock = threading.Lock() + + try: + # Create Prometheus metrics reader + prometheus_reader = PrometheusMetricReader() + + # Create meter provider with Prometheus exporter + provider = SDKMeterProvider(metric_readers=[prometheus_reader]) + set_meter_provider(provider) + + # Get meter for creating metrics + self.meter = metrics.get_meter("axolotl.training") + + # Create metrics + self._create_metrics() + + except Exception as e: + LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}") + self.metrics_enabled = False + + def _create_metrics(self): + """Create all metrics that will be tracked""" + self.train_loss_gauge = self.meter.create_gauge( + name="axolotl_train_loss", + description="Current training loss", + unit="1", + ) + + self.eval_loss_gauge = self.meter.create_gauge( + name="axolotl_eval_loss", + description="Current evaluation loss", + unit="1", + ) + + self.learning_rate_gauge = self.meter.create_gauge( + name="axolotl_learning_rate", + description="Current learning rate", + unit="1", + ) + + self.epoch_gauge = self.meter.create_gauge( + name="axolotl_epoch", + description="Current training epoch", + unit="1", + ) + + self.global_step_counter = self.meter.create_counter( + name="axolotl_global_steps", + description="Total training steps completed", + unit="1", + ) + + self.grad_norm_gauge = self.meter.create_gauge( + name="axolotl_gradient_norm", + description="Gradient norm", + unit="1", + ) + + self.memory_usage_gauge = self.meter.create_gauge( + name="axolotl_memory_usage", + description="Current memory usage in MB", + unit="MB", + ) + + def _start_metrics_server(self): + """Start the HTTP server for metrics exposure""" + if self.server_started: + return + + try: + start_http_server(self.metrics_port, addr=self.metrics_host) + self.server_started = True + LOG.info( + f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics" + ) + + except Exception as e: + LOG.error(f"Failed to start OpenTelemetry metrics server: {e}") + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Called at the beginning of training""" + if not self.metrics_enabled: + return + + self._start_metrics_server() + LOG.info("OpenTelemetry metrics collection started") + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: Optional[Dict[str, float]] = None, + **kwargs, + ): + """Called when logging occurs""" + if not self.metrics_enabled or not logs: + return + + if "loss" in logs: + self.train_loss_gauge.set(logs["loss"]) + + if "eval_loss" in logs: + self.eval_loss_gauge.set(logs["eval_loss"]) + + if "learning_rate" in logs: + self.learning_rate_gauge.set(logs["learning_rate"]) + + if "epoch" in logs: + self.epoch_gauge.set(logs["epoch"]) + + if "grad_norm" in logs: + self.grad_norm_gauge.set(logs["grad_norm"]) + if "memory_usage" in logs: + self.memory_usage_gauge.set(logs["memory_usage"]) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Called at the end of each training step""" + if not self.metrics_enabled: + return + + # Update step counter and epoch + self.global_step_counter.add(1) + if state.epoch is not None: + self.epoch_gauge.set(state.epoch) + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + metrics: Optional[Dict[str, float]] = None, + **kwargs, + ): + """Called after evaluation""" + if not self.metrics_enabled or not metrics: + return + + if "eval_loss" in metrics: + self.eval_loss_gauge.set(metrics["eval_loss"]) + + # Record any other eval metrics as gauges + for key, value in metrics.items(): + if key.startswith("eval_") and isinstance(value, (int, float)): + # Create gauge for this metric if it doesn't exist + gauge_name = f"axolotl_{key}" + try: + gauge = self.meter.create_gauge( + name=gauge_name, + description=f"Evaluation metric: {key}", + unit="1", + ) + gauge.set(value) + except Exception as e: + LOG.warning(f"Failed to create/update metric {gauge_name}: {e}") + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Called at the end of training""" + if not self.metrics_enabled: + return + + LOG.info("Training completed. OpenTelemetry metrics collection finished.") + LOG.info( + f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics" + ) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 4d1d0aab2..86b3aa17b 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -30,6 +30,7 @@ from axolotl.utils.schemas.integrations import ( GradioConfig, LISAConfig, MLFlowConfig, + OpenTelemetryConfig, RayConfig, WandbConfig, ) @@ -60,6 +61,7 @@ class AxolotlInputConfig( WandbConfig, MLFlowConfig, CometConfig, + OpenTelemetryConfig, LISAConfig, GradioConfig, RayConfig, diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index 7332c7d39..97d675569 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -176,3 +176,27 @@ class RayConfig(BaseModel): "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." }, ) + + +class OpenTelemetryConfig(BaseModel): + """OpenTelemetry configuration subset""" + + use_otel_metrics: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Enable OpenTelemetry metrics collection and Prometheus export" + }, + ) + otel_metrics_host: str | None = Field( + default="localhost", + json_schema_extra={ + "title": "OpenTelemetry Metrics Host", + "description": "Host to bind the OpenTelemetry metrics server to", + }, + ) + otel_metrics_port: int | None = Field( + default=8000, + json_schema_extra={ + "description": "Port for the Prometheus metrics HTTP server" + }, + ) diff --git a/tests/test_opentelemetry_callback.py b/tests/test_opentelemetry_callback.py new file mode 100644 index 000000000..294ff6585 --- /dev/null +++ b/tests/test_opentelemetry_callback.py @@ -0,0 +1,349 @@ +"""Tests for OpenTelemetry metrics callback functionality.""" + +import time + +import pytest + +from axolotl.utils.dict import DictDefault + + +@pytest.fixture +def mock_otel_config(): + """Mock configuration for OpenTelemetry callback.""" + return DictDefault( + { + "use_otel_metrics": True, + "otel_metrics_host": "localhost", + "otel_metrics_port": 8003, # Use unique port for tests + } + ) + + +@pytest.fixture +def mock_trainer_state(): + """Mock trainer state for callback testing.""" + from transformers import TrainerState + + state = TrainerState() + state.epoch = 1.0 + state.global_step = 100 + return state + + +@pytest.fixture +def mock_training_args(): + """Mock training arguments for callback testing.""" + from transformers import TrainingArguments + + return TrainingArguments(output_dir="/tmp/test") + + +@pytest.fixture +def mock_trainer_control(): + """Mock trainer control for callback testing.""" + from transformers.trainer_callback import TrainerControl + + return TrainerControl() + + +class TestOpenTelemetryConfig: + """Test OpenTelemetry configuration schema.""" + + def test_config_schema_valid(self): + """Test OpenTelemetry configuration schema validation.""" + from axolotl.utils.schemas.integrations import OpenTelemetryConfig + + # Test valid config + valid_config = { + "use_otel_metrics": True, + "otel_metrics_host": "localhost", + "otel_metrics_port": 8000, + } + + otel_config = OpenTelemetryConfig(**valid_config) + assert otel_config.use_otel_metrics is True + assert otel_config.otel_metrics_host == "localhost" + assert otel_config.otel_metrics_port == 8000 + + def test_config_defaults(self): + """Test OpenTelemetry configuration default values.""" + from axolotl.utils.schemas.integrations import OpenTelemetryConfig + + # Test minimal config with defaults + minimal_config = {"use_otel_metrics": True} + + otel_config = OpenTelemetryConfig(**minimal_config) + assert otel_config.use_otel_metrics is True + assert otel_config.otel_metrics_host == "localhost" # default + assert otel_config.otel_metrics_port == 8000 # default + + def test_config_disabled_by_default(self): + """Test that OpenTelemetry is disabled by default.""" + from axolotl.utils.schemas.integrations import OpenTelemetryConfig + + # Test default config + default_config = OpenTelemetryConfig() + assert default_config.use_otel_metrics is False + + +class TestOpenTelemetryCallback: + """Test OpenTelemetry callback functionality.""" + + def test_callback_import(self): + """Test that OpenTelemetry callback can be imported.""" + from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback + + assert OpenTelemetryMetricsCallback is not None + + def test_callback_graceful_fallback(self, mock_otel_config): + """Test callback gracefully handles missing dependencies.""" + from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback + + # This should not raise an exception even if dependencies are missing + callback = OpenTelemetryMetricsCallback(mock_otel_config) + + # Callback should exist but may have metrics disabled + assert callback is not None + assert hasattr(callback, "metrics_enabled") + + def test_callback_initialization_enabled(self, mock_otel_config): + """Test callback initialization when OpenTelemetry is available.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + + if OPENTELEMETRY_AVAILABLE: + assert callback.metrics_enabled is True + assert callback.cfg == mock_otel_config + assert callback.metrics_host == "localhost" + assert callback.metrics_port == 8003 + else: + assert callback.metrics_enabled is False + + def test_metrics_server_lifecycle( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test metrics server starts and stops correctly.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + + # Start server + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + assert callback.server_started is True + + # End training + callback.on_train_end( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + def test_metrics_recording( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test that metrics are recorded during training.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Test logging metrics + test_logs = { + "loss": 0.5, + "learning_rate": 1e-4, + "grad_norm": 0.8, + } + + # This should not raise an exception + callback.on_log( + mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs + ) + assert callback.metrics_enabled is True + + def test_evaluation_metrics( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test evaluation metrics recording.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Test evaluation metrics + eval_logs = { + "eval_loss": 0.3, + "eval_accuracy": 0.95, + } + + # This should not raise an exception + callback.on_evaluate( + mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs + ) + assert callback.metrics_enabled is True + + def test_thread_safety(self, mock_otel_config): + """Test that callback has thread safety mechanisms.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + assert hasattr(callback, "metrics_lock") + # Check it's a lock-like object + assert hasattr(callback.metrics_lock, "__enter__") + assert hasattr(callback.metrics_lock, "__exit__") + + +class TestOpenTelemetryIntegration: + """Integration tests for OpenTelemetry.""" + + def test_availability_check(self): + """Test availability check function.""" + from axolotl.utils import is_opentelemetry_available + + result = is_opentelemetry_available() + assert isinstance(result, bool) + + def test_prometheus_endpoint_basic( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test basic Prometheus endpoint functionality.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + try: + import requests + except ImportError: + pytest.skip("requests library not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + if not callback.server_started: + pytest.skip("Metrics server failed to start") + + # Give server time to start + time.sleep(1) + + # Try to access metrics endpoint + try: + response = requests.get( + f"http://{callback.metrics_host}:{callback.metrics_port}/metrics", + timeout=2, + ) + assert response.status_code == 200 + # Check for Prometheus format + assert "# TYPE" in response.text or "# HELP" in response.text + except requests.exceptions.RequestException: + pytest.skip( + "Could not connect to metrics endpoint - this is expected in some environments" + ) + + +class TestOpenTelemetryCallbackMethods: + """Test specific callback methods.""" + + def test_step_end_callback( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test step end callback method.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Should not raise an exception + callback.on_step_end( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + def test_epoch_end_callback( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test epoch end callback method.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Should not raise an exception + callback.on_epoch_end( + mock_training_args, mock_trainer_state, mock_trainer_control + ) From bb33fda44d8cc889230698539b8df5a7ba114b67 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 22 Oct 2025 21:24:52 -0700 Subject: [PATCH 12/16] install flash attention in 2.9.0 base images (#3224) --- docker/Dockerfile-base | 6 ++++-- docker/Dockerfile-uv-base | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 87918cc41..cc209f304 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -47,6 +47,8 @@ RUN git lfs install --skip-repo && \ pip3 install -U --no-cache-dir pydantic==1.10.10 && \ pip3 cache purge -RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \ - FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \ +RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \ + wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ + pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ + rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ fi diff --git a/docker/Dockerfile-uv-base b/docker/Dockerfile-uv-base index eaa49b9e9..2ca272c6e 100644 --- a/docker/Dockerfile-uv-base +++ b/docker/Dockerfile-uv-base @@ -34,3 +34,9 @@ RUN uv pip install packaging setuptools wheel psutil \ && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \ && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \ && uv pip install awscli pydantic + +RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \ + wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ + uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ + rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ + fi From 9d4d39e939b3e44298f0c5e1f1b05c7b515fc7a6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 27 Oct 2025 03:42:01 -0400 Subject: [PATCH 13/16] Diffusion trainer fix: shift logits to align with input tokens (#3191) * shift logits for diffusion generate * delete unused * diffusion trainer: token shift --- src/axolotl/integrations/diffusion/generation.py | 4 ++-- src/axolotl/integrations/diffusion/trainer.py | 4 ++-- src/axolotl/integrations/diffusion/utils.py | 7 +++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py index 49e3cdfae..ec517fd23 100644 --- a/src/axolotl/integrations/diffusion/generation.py +++ b/src/axolotl/integrations/diffusion/generation.py @@ -7,7 +7,7 @@ import torch from axolotl.utils.logging import get_logger -from .utils import create_bidirectional_attention_mask +from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions LOG = get_logger(__name__) @@ -360,7 +360,7 @@ def _diffusion_step( # Forward pass outputs = model(input_ids=sequence, attention_mask=attention_mask) - logits = outputs.logits + logits = shift_logits_to_input_positions(outputs.logits) # Only sample at currently masked positions if current_mask.any(): diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py index 42b2468f4..dfaef2a48 100644 --- a/src/axolotl/integrations/diffusion/trainer.py +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from .callbacks import DiffusionGenerationCallback -from .utils import create_bidirectional_attention_mask +from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions LOG = get_logger(__name__) @@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer): input_ids=noisy_batch.long(), attention_mask=bidirectional_mask, ) - logits = outputs.logits + logits = shift_logits_to_input_positions(outputs.logits) if masked_indices.sum() > 0: valid_indices = torch.where(masked_indices) diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py index 47abf6fec..b6f71c07b 100644 --- a/src/axolotl/integrations/diffusion/utils.py +++ b/src/axolotl/integrations/diffusion/utils.py @@ -157,3 +157,10 @@ def create_bidirectional_attention_mask( # Add head dimension: [batch_size, 1, seq_len, seq_len] return bidirectional_mask.unsqueeze(1) + + +def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor: + """Align next-token logits with their input token positions for diffusion.""" + if logits.size(1) <= 1: + return logits + return torch.cat([logits[:, :1], logits[:, :-1]], dim=1) From 98333e639a35bd36a108786a6daaa42f03488aca Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 Oct 2025 18:02:16 -0400 Subject: [PATCH 14/16] upgrade trl to 0.24.0 and liger to 0.6.3 (#3230) * upgrade trl to 0.24.0 * fix reward collator init * use newer DataCollatorForPreference instead * DataCollatorForPreference doesn't use padding kwarg * fix input id labels * fix fbgemm-gpu version for pytorch versions * tweak pinned deps * transformers doesn't support hub 1.0 yet * upgrade liger dep to 0.6.3 * set TORCH_CUDA_ARCH_LIST correctly --- cicd/Dockerfile.jinja | 2 +- requirements.txt | 12 ++++++------ setup.py | 8 ++++++-- src/axolotl/core/builders/causal.py | 9 ++++++--- .../prompt_strategies/bradley_terry/chat_template.py | 4 ++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 6a1ddb66d..c3a613ecc 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -1,6 +1,6 @@ FROM axolotlai/axolotl-base:{{ BASE_TAG }} -ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" +ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}" ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}" ENV CUDA="{{ CUDA }}" diff --git a/requirements.txt b/requirements.txt index e1f1b10a5..5621d94b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,27 +5,27 @@ bitsandbytes==0.47.0 triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 -liger-kernel==0.6.1 +liger-kernel==0.6.3 # END section packaging==23.2 -huggingface_hub>=0.33.0 +huggingface_hub>=0.36.0 peft>=0.17.1 tokenizers>=0.21.1 transformers==4.57.1 accelerate==1.10.1 datasets==4.0.0 deepspeed>=0.17.0 -trl==0.23.1 -hf_xet==1.1.5 -kernels==0.9.0 +trl==0.24.0 +hf_xet==1.2.0 +kernels>=0.9.0 trackio optimum==1.16.2 hf_transfer sentencepiece -gradio==5.41.1 +gradio==5.49.1 modal==1.0.2 pydantic==2.10.6 diff --git a/setup.py b/setup.py index 9e3de48b5..2845bb151 100644 --- a/setup.py +++ b/setup.py @@ -62,8 +62,12 @@ def parse_requirements(extras_require_map): else: raise ValueError("Invalid version format") - if (major, minor) >= (2, 8): - pass + if (major, minor) >= (2, 9): + extras_require_map.pop("fbgemm-gpu") + extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"] + elif (major, minor) >= (2, 8): + extras_require_map.pop("fbgemm-gpu") + extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"] elif (major, minor) >= (2, 7): _install_requires.pop(_install_requires.index(xformers_version)) if patch == 0: diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 820304230..7a06431dc 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -12,7 +12,7 @@ from transformers import ( EarlyStoppingCallback, Trainer, ) -from trl.trainer.utils import RewardDataCollatorWithPadding +from trl.trainer.reward_trainer import DataCollatorForPreference from axolotl.core.builders.base import TrainerBuilderBase from axolotl.core.trainers import ( @@ -453,7 +453,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, DataCollatorWithFlattening, - RewardDataCollatorWithPadding, + DataCollatorForPreference, ] ] collator_args = [self.tokenizer] @@ -470,7 +470,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if kwargs and isinstance(kwargs, dict): kwargs.update(collator_cls_and_kwargs[1]) elif self.cfg.reward_model: - collator = RewardDataCollatorWithPadding + collator = DataCollatorForPreference + tokenizer = collator_args.pop(0) + kwargs["pad_token_id"] = tokenizer.pad_token_id + kwargs.pop("padding") elif use_batch_sampler_collator: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # supported multipack models, or non-flash-attention llama diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index fd0d76f51..03336b3ef 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): ] return { - "input_ids_chosen": chosen_tokenized["input_ids"], + "chosen_input_ids": chosen_tokenized["input_ids"], "attention_mask_chosen": chosen_tokenized["attention_mask"], "labels_chosen": 1.0, - "input_ids_rejected": rejected_tokenized["input_ids"], + "rejected_input_ids": rejected_tokenized["input_ids"], "attention_mask_rejected": rejected_tokenized["attention_mask"], "labels_rejected": 0.0, } From a4b921135b56abad32f962009686b52089b273c9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 Oct 2025 18:07:29 -0400 Subject: [PATCH 15/16] build cuda 13.0.0 base image with 2.9.0 (#3229) * build cuda 13.0.0 base image with 2.9.0 * upgrade causal-conv1d * 1.5.4 not in pypi yet * pin to 1.3.0 * use github release instead of pypi * split the logic for incompatible packages * fix bash in dockerfile --- .github/workflows/base.yml | 14 ++++++++++++++ docker/Dockerfile-base | 8 ++++++-- setup.py | 2 +- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index b2681bb5d..87d6772dd 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -60,6 +60,13 @@ jobs: pytorch: 2.9.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-base" + - cuda: "130" + cuda_version: 13.0.0 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.9.0 + torch_cuda_arch_list: "9.0+PTX" + dockerfile: "Dockerfile-base" # - cuda: "128" # cuda_version: 12.8.1 # cudnn_version: "" @@ -143,6 +150,13 @@ jobs: pytorch: 2.9.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-uv-base" + - cuda: "130" + cuda_version: 13.0.0 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.9.0 + torch_cuda_arch_list: "9.0+PTX" + dockerfile: "Dockerfile-uv-base" steps: - name: Checkout uses: actions/checkout@v4 diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index cc209f304..a08b5cd4f 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -37,10 +37,14 @@ WORKDIR /workspace RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ - CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ - python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ python3 -m pip cache purge +RUN if [ "$CUDA" != "130" ] ; then \ + CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \ + python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \ + python3 -m pip cache purge; \ + fi + RUN git lfs install --skip-repo && \ pip3 install awscli && \ # The base image ships with `pydantic==1.8.2` which is not working diff --git a/setup.py b/setup.py index 2845bb151..b16377e92 100644 --- a/setup.py +++ b/setup.py @@ -162,7 +162,7 @@ extras_require = { "llmcompressor": [ "llmcompressor==0.5.1", ], - "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"], + "fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"], "opentelemetry": [ "opentelemetry-api", "opentelemetry-sdk", From 0f7c886b7b28a0a90a8510c58f160f6ee70e9851 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 29 Oct 2025 18:09:46 -0400 Subject: [PATCH 16/16] chore: update pre-commit hooks (#3222) [skip ci] Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e455f52c..015fb5e6e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.0 + rev: v0.14.2 hooks: - id: ruff args: [--fix]