Compare commits

...

7 Commits

Author SHA1 Message Date
Wing Lian
c6d69d5c1b release v0.11.0 (#2875)
* release v0.11.0

* don't build vllm into release for now

* remove 2.5.1 references

* smollm3 multipack support

* fix ordering of e2e tests
2025-07-09 09:22:35 -04:00
Wing Lian
4ff96a2526 fix xformers version (#2888) 2025-07-09 08:43:40 -04:00
salman
89e99eaaa7 slowest durations (#2887) [skip ci] 2025-07-09 08:43:26 -04:00
Wing Lian
6ed501f6dc add 2.7.0 torch images back to support vlllm (#2885) 2025-07-08 16:28:14 -04:00
NanoCode012
8c6a6ea6eb Feat: add devstral model support (#2880) [skip ci]
* fix: do not add training and training_detail block by default

* fixed: magistral docs

* fix: address pad adding new fields and use built-in from_openai

* feat: try enable multiprocessing

* fix: check for keys before deleting attn_mask

* feat: add mistral pad test

* feat: add tool calling test

* feat: add devstral tokenizer tests

* fix: comma format

* chore: remove unused support_preprocessing as tokenizer is pickable now

* chore: update magistral doc

* feat: add devstral readme and example

* chore: refactor error handling
2025-07-08 11:01:19 -04:00
NanoCode012
78bff4925e fix: set add_generation_prompt to False when apply chat template (#2859) [skip ci] 2025-07-08 11:00:44 -04:00
NanoCode012
b237c8a3f3 chore: update cce commit to include gemma3n fixes (#2881) [skip ci] 2025-07-08 10:59:35 -04:00
32 changed files with 751 additions and 272 deletions

View File

@@ -29,11 +29,11 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "124"
cuda_version: 12.4.1
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
@@ -43,7 +43,7 @@ jobs:
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"

View File

@@ -15,15 +15,15 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
@@ -82,17 +82,17 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -33,13 +33,6 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -12,11 +12,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -68,10 +63,10 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.6.0
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -80,9 +80,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |

View File

@@ -52,7 +52,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -102,9 +102,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -125,7 +125,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -175,9 +175,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
@@ -198,7 +198,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
@@ -252,18 +252,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1

View File

@@ -55,7 +55,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.5.1
- PyTorch ≥2.6.0
### Installation

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),

View File

@@ -36,7 +36,6 @@ Tags examples:
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
## Main
@@ -78,10 +77,9 @@ Tags examples:
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu126-2.6.0`
- `0.10.1`
## Cloud

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.5.1
- PyTorch ≥2.6.0
## Installation Methods {#sec-installation-methods}

View File

@@ -0,0 +1,69 @@
# Finetune Devstral with Axolotl
Devstral Small is a 24B-parameter open-source model from MistralAI, available on Hugging Face as [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
The model was fine-tuned on top of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context window of up to 128k tokens.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as Devstral support is only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main with pip:
```bash
# Ensure you have PyTorch installed (PyTorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install the latest mistral-common from source
pip3 uninstall mistral-common
pip3 install git+https://github.com/mistralai/mistral-common.git@039465d
```
2. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml
```
This config uses about 21 GB of VRAM.
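The figure above reflects the QLoRA and batching settings in the example config (reproduced in full further down this page); the keys below are not new, they are simply the ones from that config that drive memory use:

```yaml
# memory-related settings from the example config (a sketch, not the full file)
adapter: qlora                  # train low-rank adapters on a quantized base model
load_in_4bit: true              # load base weights in 4-bit
sequence_len: 2048
sample_packing: true
micro_batch_size: 2
gradient_accumulation_steps: 4
gradient_checkpointing: true
flash_attention: true
```

Raising `sequence_len` or `micro_batch_size` increases VRAM use accordingly.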
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format, as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template); a minimal example row is sketched below.
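For reference, a single conversation in that format looks roughly like the following, shown as YAML for readability; in a JSONL dataset each conversation is one JSON object with the same `messages` structure, and the values here are purely illustrative:

```yaml
# one training conversation in the OpenAI Messages format (illustrative content)
messages:
  - role: user
    content: "What is the capital of France?"
  - role: assistant
    content: "The capital of France is Paris."
```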
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
## Limitations
We currently support the `mistral-common` tokenizer for Supervised Fine-tuning only, and only with `type: chat_template`.
In addition, we do not support overriding tokens yet.
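In practice this means a config along the following lines (a minimal sketch; the dataset path is a placeholder, and the other keys are taken from `examples/devstral/devstral-small-qlora.yml`, shown in full below):

```yaml
# minimal keys for SFT with the mistral-common tokenizer (sketch)
base_model: mistralai/Devstral-Small-2505
tokenizer_use_mistral_common: true   # use the mistral-common tokenizer
datasets:
  - path: your/dataset               # placeholder - point this at your own data
    type: chat_template              # only chat_template is supported here
```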
## Related Resources
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity for Preference Tuning, RL, Multi-modal, etc.
- Add parity for other tokenizer config options, such as overriding tokens.

View File

@@ -0,0 +1,64 @@
base_model: mistralai/Devstral-Small-2505
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
load_in_8bit: false
load_in_4bit: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.05
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Download the example config:
```bash
axolotl fetch examples
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
@@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
@@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
In addition, we do not support overriding tokens yet.
## Related Resources

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
)

View File

@@ -66,8 +66,11 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
if patch == 0:
_install_requires.append("xformers==0.0.30")
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.11.0.dev"
__version__ = "0.11.0"

View File

@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
```
## Usage

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
)

View File

@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0
torch_compile: True # torch>=2.5.1, recommended to reduce vram
torch_compile: True # torch>=2.6.0, recommended to reduce vram
datasets:
- path: ...

View File

@@ -35,6 +35,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"deepseek_v3",
"glm",
"glm4",
"smollm3",
]

View File

@@ -681,13 +681,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
for message in messages:
transformed_message = self.transform_message(message)
turn = {
**transformed_message,
"training": message.get(self.prompter.message_field_training),
"training_detail": message.get(
self.prompter.message_field_training_detail
),
}
turn = transformed_message
training = message.get(self.prompter.message_field_training)
training_detail = message.get(self.prompter.message_field_training_detail)
if training is not None:
turn["training"] = training
if training_detail is not None:
turn["training_detail"] = training_detail
turns.append(turn)
@@ -859,15 +860,6 @@ class MistralStrategy(ChatTemplateStrategy):
# TODO: address this in the future with mistral-specific checks
# self._validate_eot_and_eos_tokens()
@property
def supports_multiprocessing(self) -> bool:
"""
Whether this tokenizing strategy supports multiprocessing.
mistral_common tokenizers cannot be pickled for multiprocessing.
"""
return False
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# mistral-common tokenizer does not support eot_tokens

View File

@@ -70,14 +70,6 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self):
return False
@property
def supports_multiprocessing(self):
"""
Whether this tokenizing strategy supports multiprocessing.
Should return False if the tokenizer has unpicklable objects.
"""
return True
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:

View File

@@ -108,7 +108,7 @@ class DataCollatorForSeq2Seq:
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
if not has_attn_mask:
if not has_attn_mask and "attention_mask" in features:
del features["attention_mask"]
# prepare decoder_input_ids

View File

@@ -50,7 +50,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
# This method requires transformers>=4.49.0
result = self.processing_strategy.processor.apply_chat_template(
example["messages"],
add_generation_prompt=True,
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,

View File

@@ -3,10 +3,11 @@
import math
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Optional
from typing import Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from torch import Tensor
@@ -14,9 +15,6 @@ from transformers.utils import PaddingStrategy
from axolotl.utils.collators.core import IGNORE_INDEX
if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import ChatCompletionRequest
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
"""Get the file path from local or HF Hub"""
@@ -259,75 +257,6 @@ class HFMistralTokenizer:
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None
) -> "ChatCompletionRequest":
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
[]
)
for turn in conversation:
role = turn.get("role")
if role == "user":
messages.append(UserMessage(content=turn["content"]))
elif role == "assistant":
messages.append(
AssistantMessage(
content=turn.get("content"),
tool_calls=turn.get("tool_calls"),
)
)
elif role == "tool":
messages.append(
ToolMessage(
content=turn.get("content"),
tool_call_id=turn.get("tool_call_id"),
name=turn.get("name"),
)
)
elif role == "system":
messages.append(SystemMessage(content=turn["content"]))
else:
raise ValueError(
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
)
tool_calls: list[Tool] = []
if tools:
# convert to Tool
for tool in tools:
if tool["type"] != "function":
continue
function = tool["function"]
tool_calls.append(
Tool(
function=Function(
name=function["name"],
description=function["description"],
# set parameters to empty dict if not provided
parameters=function.get("parameters", {}),
)
)
)
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
messages=messages,
tools=tool_calls,
)
return chat_completion
def apply_chat_template(
self,
messages: list[dict],
@@ -342,8 +271,8 @@ class HFMistralTokenizer:
if add_generation_prompt:
raise NotImplementedError("add_generation_prompt not supported yet")
chat_completion: ChatCompletionRequest = (
self._create_mistral_chat_completion_request(messages, tools)
chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai(
messages, tools
)
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
@@ -408,13 +337,16 @@ class HFMistralTokenizer:
padding_value=IGNORE_INDEX,
)
attention_mask = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=0,
)
attention_mask = None
if "attention_mask" in features[0]:
attention_mask = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=0,
)
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
position_ids = None
if "position_ids" in features[0]:
if self.padding_side == "left":
# Likely not needed, but keeping for now
@@ -443,22 +375,15 @@ class HFMistralTokenizer:
pos_seq = torch.cat([pos_seq, pad_positions])
position_ids_list.append(pos_seq)
position_ids = torch.stack(position_ids_list)
else:
# Create position_ids if not present
seq_len = input_ids.size(1)
position_ids = (
torch.arange(seq_len, dtype=torch.long)
.unsqueeze(0)
.expand(input_ids.size(0), -1)
)
# Ensure all tensors have the same sequence length
max_seq_len = max(
input_ids.size(1),
labels.size(1),
attention_mask.size(1),
position_ids.size(1),
)
# Check attention mask and position ids if they are present
tensor_lengths = [input_ids.size(1), labels.size(1)]
if attention_mask is not None:
tensor_lengths.append(attention_mask.size(1))
if position_ids is not None:
tensor_lengths.append(position_ids.size(1))
max_seq_len = max(tensor_lengths)
# TODO: check if trimming is needed? and correct.
@@ -492,44 +417,48 @@ class HFMistralTokenizer:
elif labels.size(1) > max_seq_len:
labels = labels[:, :max_seq_len]
if attention_mask.size(1) < max_seq_len:
pad_len = max_seq_len - attention_mask.size(1)
if self.padding_side == "right":
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
else:
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
elif attention_mask.size(1) > max_seq_len:
attention_mask = attention_mask[:, :max_seq_len]
if attention_mask is not None:
if attention_mask.size(1) < max_seq_len:
pad_len = max_seq_len - attention_mask.size(1)
if self.padding_side == "right":
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
else:
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
elif attention_mask.size(1) > max_seq_len:
attention_mask = attention_mask[:, :max_seq_len]
if position_ids.size(1) < max_seq_len:
pad_len = max_seq_len - position_ids.size(1)
if self.padding_side == "right":
batch_size = position_ids.size(0)
new_position_ids = []
for i in range(batch_size):
seq = position_ids[i]
if len(seq) > 0:
# get last position and pad with sequential values
last_pos = seq[-1].item()
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
new_seq = torch.cat([seq, pad_positions])
else:
new_seq = torch.arange(pad_len, dtype=torch.long)
new_position_ids.append(new_seq)
position_ids = torch.stack(new_position_ids)
else:
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
elif position_ids.size(1) > max_seq_len:
position_ids = position_ids[:, :max_seq_len]
if position_ids is not None:
if position_ids.size(1) < max_seq_len:
pad_len = max_seq_len - position_ids.size(1)
if self.padding_side == "right":
batch_size = position_ids.size(0)
new_position_ids = []
for i in range(batch_size):
seq = position_ids[i]
if len(seq) > 0:
# get last position and pad with sequential values
last_pos = seq[-1].item()
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
new_seq = torch.cat([seq, pad_positions])
else:
new_seq = torch.arange(pad_len, dtype=torch.long)
new_position_ids.append(new_seq)
position_ids = torch.stack(new_position_ids)
else:
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
elif position_ids.size(1) > max_seq_len:
position_ids = position_ids[:, :max_seq_len]
final_batch = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if attention_mask is not None:
final_batch["attention_mask"] = attention_mask
if position_ids is not None:
final_batch["position_ids"] = position_ids
# Handle non-sequence fields (raise error)
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
@@ -545,7 +474,7 @@ class HFMistralTokenizer:
result = {}
for k, v in final_batch.items():
if isinstance(v, torch.Tensor):
result[k] = v.numpy().astype(np.long)
result[k] = v.numpy().astype(np.int64)
else:
result[k] = v
return result

View File

@@ -627,7 +627,7 @@ class AxolotlInputConfig(
torch_compile: Literal["auto"] | bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.5.1"
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
},
)
torch_compile_backend: str | None = Field(
@@ -1083,9 +1083,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
def check_min_torch_version(self):
if self.env_capabilities and self.env_capabilities.torch_version:
torch_version = self.env_capabilities.torch_version
if version.parse(torch_version) < version.parse("2.5.1"):
if version.parse(torch_version) < version.parse("2.6.0"):
LOG.warning(
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
f"torch=={torch_version} not be supported. Please upgrade to torch>=2.6.0."
)
return self

View File

@@ -692,7 +692,7 @@ class TestValidation(BaseValidation):
"bf16": True,
"capabilities": {"bf16": False},
"env_capabilities": {
"torch_version": "2.5.1",
"torch_version": "2.6.0",
},
}
)
@@ -1202,7 +1202,7 @@ class TestValidation(BaseValidation):
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
env_capabilities = {"torch_version": "2.5.1"}
env_capabilities = {"torch_version": "2.6.0"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
@@ -1244,7 +1244,7 @@ class TestTorchCompileValidation(BaseValidation):
| minimal_cfg
)
env_capabilities = {"torch_version": "2.5.1"}
env_capabilities = {"torch_version": "2.6.0"}
capabilities = {"bf16": True}
updated_cfg = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities

View File

@@ -164,6 +164,14 @@ def fixture_magistral_tokenizer():
return tokenizer
@pytest.fixture(name="devstral_tokenizer")
def fixture_devstral_tokenizer():
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'

View File

@@ -3,32 +3,50 @@
import unittest
from typing import TYPE_CHECKING
import pytest
if TYPE_CHECKING:
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
# fmt: off
@pytest.mark.parametrize(
("tokenizer_str", "assistant_toolcall_ids"),
(
("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2)),
("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2)),
)
)
# fmt: on
def test_mistral_chat_template(
tokenizer_str: str,
assistant_toolcall_ids: tuple[int, ...],
request: pytest.FixtureRequest,
):
"""Test chat template with the Magistral/Devstral tokenizer"""
# pylint: disable=duplicate-code
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
# check bos, eos, pad, unk are accessible properties
assert magistral_tokenizer.bos_token_id == 1
assert magistral_tokenizer.eos_token_id == 2
assert magistral_tokenizer.pad_token_id == 11
assert magistral_tokenizer.unk_token_id == 0
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)
assert magistral_tokenizer.pad_token == "<pad>"
assert magistral_tokenizer.eos_token == "</s>"
assert magistral_tokenizer.bos_token == "<s>"
assert magistral_tokenizer.unk_token == "<unk>"
# check bos, eos, pad, unk are accessible properties
assert tokenizer.bos_token_id == 1
assert tokenizer.eos_token_id == 2
assert tokenizer.pad_token_id == 11
assert tokenizer.unk_token_id == 0
assert tokenizer.pad_token == "<pad>"
assert tokenizer.eos_token == "</s>"
assert tokenizer.bos_token == "<s>"
assert tokenizer.unk_token == "<unk>"
strategy = MistralStrategy(
MistralPrompter(
magistral_tokenizer,
tokenizer,
chat_template=None,
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=magistral_tokenizer,
tokenizer=tokenizer,
train_on_inputs=False,
train_on_eos="turn",
sequence_len=512,
@@ -219,7 +237,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
1, # bos
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
*assistant_toolcall_ids, # assistant tool calling
7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
@@ -229,7 +247,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
*assistant_toolcall_ids, # assistant tool calling
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
@@ -237,7 +255,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
# fmt: on
# test chat template with tokenize=False
res = magistral_tokenizer.apply_chat_template(
res = tokenizer.apply_chat_template(
[
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
@@ -248,7 +266,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
# test encode
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True)
res = tokenizer.encode("Hello, how are you?", add_special_tokens=True)
assert res == [
1, # bos
22177, # Hello
@@ -261,16 +279,16 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
]
# test decode no skip special tokens
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False)
decoded_res = tokenizer.decode(res, skip_special_tokens=False)
assert decoded_res == "<s>Hello, how are you?</s>"
# test decode skip special tokens
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True)
decoded_res = tokenizer.decode(res, skip_special_tokens=True)
assert decoded_res == "Hello, how are you?"
# test encode no special tokens
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False)
res = tokenizer.encode("Hello, how are you?", add_special_tokens=False)
assert res == [
22177, # Hello
1044, # ,
@@ -281,10 +299,452 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
]
# test convert ids to tokens
res = magistral_tokenizer.convert_ids_to_tokens(res)
res = tokenizer.convert_ids_to_tokens(res)
# spacing are needed as we are converting without decoding
assert res == ["Hello", ",", " how", " are", " you", "?"]
def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"):
"""Test the MistralTokenizer pad method"""
from axolotl.utils.collators.core import IGNORE_INDEX
magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id
# Test padding with input_ids and labels only
features = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
{"input_ids": [7, 8], "labels": [9, 10]},
]
result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt")
# Check that input_ids are padded correctly
assert result["input_ids"].shape == (2, 3)
assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]]
# Check that labels are padded correctly
assert result["labels"].shape == (2, 3)
assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]]
# Check that attention_mask and position_ids are NOT created
assert "attention_mask" not in result
assert "position_ids" not in result
# Test padding with attention_mask
features_with_attention = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]},
{"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]},
]
result = magistral_tokenizer.pad(
features_with_attention, padding=True, return_tensors="pt"
)
# Check that attention_mask is padded correctly
assert result["attention_mask"].shape == (2, 3)
assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]]
# Test padding with position_ids
features_with_position = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]},
{"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]},
]
result = magistral_tokenizer.pad(
features_with_position, padding=True, return_tensors="pt"
)
# Check that position_ids are padded correctly (continuing sequence)
assert result["position_ids"].shape == (2, 3)
assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]]
# Test padding with all fields
features_all = [
{
"input_ids": [1, 2, 3],
"labels": [4, 5, 6],
"attention_mask": [1, 1, 1],
"position_ids": [0, 1, 2],
},
{
"input_ids": [7, 8],
"labels": [9, 10],
"attention_mask": [1, 1],
"position_ids": [0, 1],
},
]
result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt")
# All fields should be present and correctly padded
assert "input_ids" in result
assert "labels" in result
assert "attention_mask" in result
assert "position_ids" in result
# Test padding with all sequences same length
features_same_length = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
{"input_ids": [7, 8, 9], "labels": [10, 11, 12]},
]
result = magistral_tokenizer.pad(
features_same_length, padding=True, return_tensors="pt"
)
# Check match when no padding is needed
assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"]
assert result["labels"][0].tolist() == features_same_length[0]["labels"]
assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"]
assert result["labels"][1].tolist() == features_same_length[1]["labels"]
# Test padding with max_length parameter
result = magistral_tokenizer.pad(
features, padding="max_length", max_length=5, return_tensors="pt"
)
# Should pad to max_length
assert result["input_ids"].shape == (2, 5)
assert result["labels"].shape == (2, 5)
# Test numpy return type
result = magistral_tokenizer.pad(features, padding=True, return_tensors="np")
# Should return numpy arrays
import numpy as np
assert isinstance(result["input_ids"], np.ndarray)
assert isinstance(result["labels"], np.ndarray)
# Test unsupported field rejection
features_unsupported = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]},
]
with pytest.raises(NotImplementedError, match="unsupported_field"):
magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt")
# Test token_type_ids rejection
features_token_type = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]},
]
with pytest.raises(ValueError, match="token_type_ids is not supported"):
magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt")
def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
"""Test tool calling with the Magistral tokenizer"""
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
strategy = MistralStrategy(
MistralPrompter(
magistral_tokenizer,
chat_template=None,
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=magistral_tokenizer,
train_on_inputs=False,
train_on_eos="turn",
sequence_len=512,
roles_to_train=["assistant"],
)
# Test basic tool calling with single function
basic_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
"required": ["location"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "What's the weather like in San Francisco?",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {
"location": "San Francisco, CA",
},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "get_weather",
"content": "Sunny, 72°F",
},
{
"role": "assistant",
"content": "The weather in San Francisco is sunny and 72°F.",
},
],
}
res = strategy.tokenize_prompt(basic_tool_calling)
# Basic validation
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
assert len(res["labels"]) == len(res["input_ids"])
# Decode and verify structure
decoded = magistral_tokenizer.decode(res["input_ids"])
assert (
'<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]'
in decoded
)
assert (
'[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}</s>'
in decoded
)
assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded
assert "The weather in San Francisco is sunny and 72°F.</s>" in decoded
# Test multiple tool calls in sequence
multi_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
},
},
},
{
"type": "function",
"function": {
"name": "multiply_numbers",
"description": "Multiply two numbers",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "number", "description": "First number"},
"y": {"type": "number", "description": "Second number"},
},
"required": ["x", "y"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "Add 5 and 3, then multiply the result by 2",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "add_numbers",
"arguments": {"a": 5, "b": 3},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "add_numbers",
"content": "8",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call23456",
"type": "function",
"function": {
"name": "multiply_numbers",
"arguments": {"x": 8, "y": 2},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call23456",
"name": "multiply_numbers",
"content": "16",
},
{
"role": "assistant",
"content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.",
},
],
}
res = strategy.tokenize_prompt(multi_tool_calling)
# Validation
assert len(res["input_ids"]) > 0
assert len(res["labels"]) == len(res["input_ids"])
decoded = magistral_tokenizer.decode(res["input_ids"])
assert (
'<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]'
in decoded
)
assert (
'[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}</s>' in decoded
)
assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded
assert (
'[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}</s>'
in decoded
)
assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded
assert (
"The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.</s>"
in decoded
)
# Test tool calling with system message
system_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "search_database",
"description": "Search for information in database",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
},
"required": ["query"],
},
},
},
],
"messages": [
{
"role": "system",
"content": "You are a helpful assistant with access to a database.",
},
{
"role": "user",
"content": "Find information about Python programming",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "search123",
"type": "function",
"function": {
"name": "search_database",
"arguments": {"query": "Python programming"},
},
}
],
},
{
"role": "tool",
"tool_call_id": "search123",
"name": "search_database",
"content": "Python is a high-level programming language known for its simplicity.",
},
{
"role": "assistant",
"content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.",
},
],
}
res = strategy.tokenize_prompt(system_tool_calling)
# Validation
assert len(res["input_ids"]) > 0
assert len(res["labels"]) == len(res["input_ids"])
decoded = magistral_tokenizer.decode(res["input_ids"])
assert (
'<s>[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]'
in decoded
)
# Test error handling - missing tool response
incomplete_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "get_time",
"description": "Get current time",
"parameters": {"type": "object", "properties": {}},
},
},
],
"messages": [
{
"role": "user",
"content": "What time is it?",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "time12345",
"type": "function",
"function": {
"name": "get_time",
"arguments": {},
},
}
],
},
{
"role": "assistant",
"content": "The current time is 12:00 PM.",
},
],
}
from mistral_common.exceptions import InvalidMessageStructureException
try:
strategy.tokenize_prompt(incomplete_tool_calling)
except InvalidMessageStructureException as e:
assert "Not the same number of function calls and responses" in str(e)
if __name__ == "__main__":
unittest.main()

View File

@@ -73,7 +73,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
"torch_version": "2.6.0",
},
)
@@ -128,7 +128,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
"torch_version": "2.6.0",
},
)
@@ -184,7 +184,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
"torch_version": "2.6.0",
},
)
@@ -241,7 +241,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
"torch_version": "2.6.0",
},
)