From c51d6b06c3b25cfc94dd2ca08fc31b3a73ac4c29 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 19 Sep 2025 17:34:04 +0700 Subject: [PATCH] feat: add apertus model and cce (#3144) [skip ci] * feat: add apertus, glm4v, glm4v_moe cce * fix: arcee docs * feat: add apertus * feat: added vram usage * fix: add apertus note * feat: update doc on apertus xielu * fix: add monkeypatch for xielu activation issue * fix: simplify env * feat: pin commit * feat: add packing * chore: move patch calling * Update examples/apertus/README.md Co-authored-by: salman * Update examples/apertus/README.md Co-authored-by: salman * Update examples/apertus/README.md Co-authored-by: salman --------- Co-authored-by: salman --- examples/apertus/README.md | 110 ++++++++++++++++++ examples/apertus/apertus-8b-qlora.yaml | 64 ++++++++++ examples/arcee/README.md | 3 + .../colab-axolotl-example.ipynb | 2 +- scripts/cutcrossentropy_install.py | 2 +- .../integrations/cut_cross_entropy/README.md | 2 +- .../cut_cross_entropy/__init__.py | 2 +- src/axolotl/loaders/patch_manager.py | 12 +- .../monkeypatch/models/apertus/__init__.py | 0 .../monkeypatch/models/apertus/activation.py | 52 +++++++++ src/axolotl/monkeypatch/multipack.py | 1 + 11 files changed, 245 insertions(+), 5 deletions(-) create mode 100644 examples/apertus/README.md create mode 100644 examples/apertus/apertus-8b-qlora.yaml create mode 100644 src/axolotl/monkeypatch/models/apertus/__init__.py create mode 100644 src/axolotl/monkeypatch/models/apertus/activation.py diff --git a/examples/apertus/README.md b/examples/apertus/README.md new file mode 100644 index 000000000..774286333 --- /dev/null +++ b/examples/apertus/README.md @@ -0,0 +1,110 @@ +# Finetune Swiss-AI's Apertus with Axolotl + +[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. 
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+
+    Here is an example of how to install from main for pip:
+
+```bash
+# Ensure you have PyTorch installed (PyTorch 2.6.0 min)
+git clone https://github.com/axolotl-ai-cloud/axolotl.git
+cd axolotl
+
+pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install --no-build-isolation -e '.[flash-attn]'
+
+# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
+python scripts/cutcrossentropy_install.py | sh
+```
+
+2. (Optional, highly recommended) Install XIELU CUDA
+
+```bash
+## Recommended for reduced VRAM and faster speeds
+
+# Point to CUDA toolkit directory
+# For those using our Docker image, use the below path.
+export CUDA_HOME=/usr/local/cuda
+
+pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
+```
+
+For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
+
+3. Run the finetuning example:
+
+```bash
+axolotl train examples/apertus/apertus-8b-qlora.yaml
+```
+
+This config uses about 8.7 GiB VRAM.
+
+Let us know how it goes. Happy finetuning! 🚀
+
+### Tips
+
+- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
+- You can instead use full-parameter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
+- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+
+### XIELU Installation Issues
+
+#### `ModuleNotFoundError: No module named 'torch'`
+
+Please check these one by one:
+- Running in correct environment
+- Env has PyTorch installed
+- CUDA toolkit is at `CUDA_HOME`
+
+If those didn't help, please try the below solutions:
+
+1. Pass env for CMAKE and try install again:
+
+    ```bash
+    Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
+    ```
+
+2. Git clone the repo and manually hardcode python path:
+
+    ```bash
+    git clone https://github.com/nickjbrowning/XIELU
+    cd XIELU
+    git checkout 59d6031
+
+    # edit the build file to hardcode the python path
+    nano CMakeLists.txt # or vi depending on your preference
+    ```
+
+    ```diff
+    execute_process(
+    -    COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
+    +    COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
+        RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
+        OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
+        ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
+    )
+    ```
+
+    ```bash
+    pip3 install . 
--no-build-isolation --no-deps + ``` + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/apertus/apertus-8b-qlora.yaml b/examples/apertus/apertus-8b-qlora.yaml new file mode 100644 index 000000000..521b282da --- /dev/null +++ b/examples/apertus/apertus-8b-qlora.yaml @@ -0,0 +1,64 @@ +base_model: swiss-ai/Apertus-8B-Instruct-2509 + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate 
checkpoint saving works with your config diff --git a/examples/arcee/README.md b/examples/arcee/README.md index 217893306..23f63663e 100644 --- a/examples/arcee/README.md +++ b/examples/arcee/README.md @@ -19,6 +19,9 @@ cd axolotl pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh ``` 2. Run the finetuning example: diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 774b78b82..e63632e7c 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc\"" ] }, { diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index 5b49e7427..ada574805 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 393412f64..2361dde4a 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ 
b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc" ``` ## Usage diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index d0eb1ebdb..dad3f7f89 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"`' ) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 98eb07b0f..a78f8b965 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -68,11 +68,12 @@ class PatchManager: self._apply_self_attention_lora_patch() self._apply_fsdp2_bnb_patches() self._apply_patch_deepspeed_zero3() + self._apply_voxtral_patches() + self._apply_apertus_patches() def apply_post_plugin_pre_model_load_patches(self): """Apply post plugin-pre_model_load load patches based on config.""" self._apply_tiled_mlp(self.cfg.model_config_type) - self._apply_voxtral_patches() def _apply_transformers_patches(self): from axolotl.monkeypatch.transformers.trainer_loss_calc import ( @@ -493,3 +494,12 @@ class PatchManager: apply_deepspeed_patches() except ImportError as e: 
"""Monkeypatch for Apertus to work around a dtype mismatch in the XIELU activation."""

from torch import Tensor


def patch_apertus_xielu_activation():
    """Patch ``XIELUActivation._xielu_cuda`` so the CUDA kernel always sees a 3D tensor.

    The wrapper reshapes inputs of other ranks to 3D (and back afterwards),
    casts the learned ``alpha_p``/``alpha_n`` parameters to the input dtype,
    and acts as a firewall so ``torch.compile`` never traces the ``.item()``
    calls hidden behind the CUDA kernel object.

    Returns:
        A zero-argument ``unpatch`` callable that restores the original method.

    Raises:
        ImportError: if the installed transformers version does not provide
            ``XIELUActivation`` (requires transformers >= 4.56.1).
    """
    try:
        from transformers.activations import XIELUActivation
    except ImportError as err:
        raise ImportError(
            "Cannot import XIELUActivation. "
            "Please make sure to update your transformers version >= 4.56.1."
        ) from err

    from transformers.activations import logger

    # Idempotence guard: a second call would otherwise capture the already
    # patched method as "the original", so unpatch could never restore it.
    if getattr(XIELUActivation._xielu_cuda, "_axolotl_xielu_patch", False):
        return lambda: None

    # Keep a reference to the original method so it can be restored later.
    old_fn = XIELUActivation._xielu_cuda

    def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:
        """Firewall function to prevent torch.compile from seeing .item() calls"""
        original_shape = x.shape
        # The CUDA kernel expects 3D tensors: pad missing leading dims ...
        while x.dim() < 3:
            x = x.unsqueeze(0)
        # ... or collapse extra leading dims. reshape (not view) so that
        # non-contiguous inputs do not raise a RuntimeError.
        if x.dim() > 3:
            x = x.reshape(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            # Cast parameters to the input dtype to avoid the dtype mismatch.
            self.alpha_p.to(x.dtype),
            self.alpha_n.to(x.dtype),
            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.reshape(original_shape)

    # Marker checked by the idempotence guard above.
    _xielu_cuda_fixed._axolotl_xielu_patch = True

    # Apply the patch
    XIELUActivation._xielu_cuda = _xielu_cuda_fixed

    def unpatch():
        """Restore the original method"""
        XIELUActivation._xielu_cuda = old_fn

    return unpatch