Compare commits
8 Commits
fix/rl-tra
...
v0.11.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6d69d5c1b | ||
|
|
4ff96a2526 | ||
|
|
89e99eaaa7 | ||
|
|
6ed501f6dc | ||
|
|
8c6a6ea6eb | ||
|
|
78bff4925e | ||
|
|
b237c8a3f3 | ||
|
|
1032e22650 |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -29,11 +29,11 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "124"
|
- cuda: "126"
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
@@ -43,7 +43,7 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.7.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "126"
|
- cuda: "126"
|
||||||
|
|||||||
20
.github/workflows/main.yml
vendored
20
.github/workflows/main.yml
vendored
@@ -15,15 +15,15 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
@@ -82,17 +82,17 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
7
.github/workflows/multi-gpu-e2e.yml
vendored
7
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -33,13 +33,6 @@ jobs:
|
|||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
num_gpus: 2
|
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
11
.github/workflows/nightlies.yml
vendored
11
.github/workflows/nightlies.yml
vendored
@@ -12,11 +12,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -68,10 +63,10 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
8
.github/workflows/tests-nightly.yml
vendored
8
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
max-parallel: 2
|
max-parallel: 2
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
pytorch_version: ["2.6.0", "2.7.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -80,9 +80,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
30
.github/workflows/tests.yml
vendored
30
.github/workflows/tests.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -102,9 +102,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
@@ -125,7 +125,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -175,9 +175,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -198,7 +198,7 @@ jobs:
|
|||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.7.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
@@ -252,18 +252,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: llmcompressor
|
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ Features:
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python 3.11
|
||||||
- PyTorch ≥2.5.1
|
- PyTorch ≥2.6.0
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
|
|||||||
df_args = {
|
df_args = {
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||||
"CUDA": os.environ.get("CUDA", "124"),
|
"CUDA": os.environ.get("CUDA", "126"),
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template(dockerfile)
|
|||||||
df_args = {
|
df_args = {
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||||
"CUDA": os.environ.get("CUDA", "124"),
|
"CUDA": os.environ.get("CUDA", "126"),
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ Tags examples:
|
|||||||
- `main-base-py3.11-cu126-2.7.1`
|
- `main-base-py3.11-cu126-2.7.1`
|
||||||
- `main-base-py3.11-cu126-2.6.0`
|
- `main-base-py3.11-cu126-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.6.0`
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.5.1`
|
|
||||||
|
|
||||||
## Main
|
## Main
|
||||||
|
|
||||||
@@ -78,10 +77,9 @@ Tags examples:
|
|||||||
- `main-py3.11-cu126-2.7.1`
|
- `main-py3.11-cu126-2.7.1`
|
||||||
- `main-py3.11-cu126-2.6.0`
|
- `main-py3.11-cu126-2.6.0`
|
||||||
- `main-py3.11-cu124-2.6.0`
|
- `main-py3.11-cu124-2.6.0`
|
||||||
- `main-py3.11-cu124-2.5.1`
|
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
- `main-20250303-py3.11-cu124-2.5.1`
|
- `main-20250303-py3.11-cu126-2.6.0`
|
||||||
- `0.10.1`
|
- `0.10.1`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||||
|
|
||||||
1. Set `adapter: qlora` in your axolotl config file.
|
1. Set `adapter: qlora` in your axolotl config file.
|
||||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||||
|
|
||||||
## Example Config
|
## Example Config
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python ≥3.11
|
- Python ≥3.11
|
||||||
- PyTorch ≥2.5.1
|
- PyTorch ≥2.6.0
|
||||||
|
|
||||||
## Installation Methods {#sec-installation-methods}
|
## Installation Methods {#sec-installation-methods}
|
||||||
|
|
||||||
|
|||||||
69
examples/devstral/README.md
Normal file
69
examples/devstral/README.md
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# Finetune Devstral with Axolotl
|
||||||
|
|
||||||
|
Devstral Small is a 24B-parameter open-source model from MistralAI, available on HuggingFace as [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
|
||||||
|
|
||||||
|
The model was fine-tuned on top of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context length of up to 128k tokens.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Devstral support is currently only available on nightly, so you need to either install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure you have PyTorch installed (PyTorch 2.6.0+)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
|
# Install the latest mistral-common from source
|
||||||
|
pip3 uninstall mistral-common
|
||||||
|
pip3 install git+https://github.com/mistralai/mistral-common.git@039465d
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/devstral/devstral-small-qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 21GB VRAM.
|
||||||
|
|
||||||
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||||
|
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||||
|
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||||
|
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||||
|
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||||
|
|
||||||
|
In addition, we do not support overriding tokens yet.
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
|
|
||||||
|
|
||||||
|
## Future Work
|
||||||
|
|
||||||
|
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||||
|
- Add parity to other tokenizer configs like overriding tokens.
|
||||||
64
examples/devstral/devstral-small-qlora.yml
Normal file
64
examples/devstral/devstral-small-qlora.yml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: mistralai/Devstral-Small-2505
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
# Enable to use mistral-common tokenizer
|
||||||
|
tokenizer_use_mistral_common: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_ratio: 0.05
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
@@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Download the example config:
|
2. Run the finetuning example:
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl fetch examples
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||||
@@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|
||||||
@@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||||
|
|
||||||
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
In addition, we do not support overriding tokens yet.
|
||||||
|
|
||||||
## Related Resources
|
## Related Resources
|
||||||
|
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
||||||
)
|
)
|
||||||
|
|||||||
7
setup.py
7
setup.py
@@ -66,8 +66,11 @@ def parse_requirements(extras_require_map):
|
|||||||
|
|
||||||
if (major, minor) >= (2, 7):
|
if (major, minor) >= (2, 7):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
if patch == 0:
|
||||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
_install_requires.append("xformers==0.0.30")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers==0.0.31.post1")
|
||||||
|
extras_require_map["vllm"] = ["vllm>=0.9.0"]
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append(
|
_install_requires.append(
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.11.0.dev"
|
__version__ = "0.11.0"
|
||||||
|
|||||||
@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||||
|
|
||||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
|
||||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
|
||||||
LOG.info(
|
|
||||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
|
||||||
)
|
|
||||||
num_proc = 1
|
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
|
|||||||
kd_alpha: 0.9
|
kd_alpha: 0.9
|
||||||
kd_temperature: 1.0
|
kd_temperature: 1.0
|
||||||
|
|
||||||
torch_compile: True # torch>=2.5.1, recommended to reduce vram
|
torch_compile: True # torch>=2.6.0, recommended to reduce vram
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"deepseek_v3",
|
"deepseek_v3",
|
||||||
"glm",
|
"glm",
|
||||||
"glm4",
|
"glm4",
|
||||||
|
"smollm3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -681,13 +681,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
transformed_message = self.transform_message(message)
|
transformed_message = self.transform_message(message)
|
||||||
|
|
||||||
turn = {
|
turn = transformed_message
|
||||||
**transformed_message,
|
|
||||||
"training": message.get(self.prompter.message_field_training),
|
training = message.get(self.prompter.message_field_training)
|
||||||
"training_detail": message.get(
|
training_detail = message.get(self.prompter.message_field_training_detail)
|
||||||
self.prompter.message_field_training_detail
|
if training is not None:
|
||||||
),
|
turn["training"] = training
|
||||||
}
|
if training_detail is not None:
|
||||||
|
turn["training_detail"] = training_detail
|
||||||
|
|
||||||
turns.append(turn)
|
turns.append(turn)
|
||||||
|
|
||||||
@@ -859,15 +860,6 @@ class MistralStrategy(ChatTemplateStrategy):
|
|||||||
# TODO: address this in the future with mistral-specific checks
|
# TODO: address this in the future with mistral-specific checks
|
||||||
# self._validate_eot_and_eos_tokens()
|
# self._validate_eot_and_eos_tokens()
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_multiprocessing(self) -> bool:
|
|
||||||
"""
|
|
||||||
Whether this tokenizing strategy supports multiprocessing.
|
|
||||||
mistral_common tokenizers cannot be pickled for multiprocessing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def find_first_eot_token(self, input_ids, start_idx):
|
def find_first_eot_token(self, input_ids, start_idx):
|
||||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||||
# mistral-common tokenizer does not support eot_tokens
|
# mistral-common tokenizer does not support eot_tokens
|
||||||
|
|||||||
@@ -70,14 +70,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
def supports_batched(self):
|
def supports_batched(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_multiprocessing(self):
|
|
||||||
"""
|
|
||||||
Whether this tokenizing strategy supports multiprocessing.
|
|
||||||
Should return False if the tokenizer has unpicklable objects.
|
|
||||||
"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _tokenize(
|
def _tokenize(
|
||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
)
|
)
|
||||||
if not has_attn_mask:
|
if not has_attn_mask and "attention_mask" in features:
|
||||||
del features["attention_mask"]
|
del features["attention_mask"]
|
||||||
|
|
||||||
# prepare decoder_input_ids
|
# prepare decoder_input_ids
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
# This method requires transformers>=4.49.0
|
# This method requires transformers>=4.49.0
|
||||||
result = self.processing_strategy.processor.apply_chat_template(
|
result = self.processing_strategy.processor.apply_chat_template(
|
||||||
example["messages"],
|
example["messages"],
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=False,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
|
|||||||
@@ -3,10 +3,11 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||||
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@@ -14,9 +15,6 @@ from transformers.utils import PaddingStrategy
|
|||||||
|
|
||||||
from axolotl.utils.collators.core import IGNORE_INDEX
|
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
||||||
|
|
||||||
|
|
||||||
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
|
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
|
||||||
"""Get the file path from local or HF Hub"""
|
"""Get the file path from local or HF Hub"""
|
||||||
@@ -259,75 +257,6 @@ class HFMistralTokenizer:
|
|||||||
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_mistral_chat_completion_request(
|
|
||||||
self, conversation: list[dict], tools: list[dict] | None = None
|
|
||||||
) -> "ChatCompletionRequest":
|
|
||||||
from mistral_common.protocol.instruct.messages import (
|
|
||||||
AssistantMessage,
|
|
||||||
SystemMessage,
|
|
||||||
ToolMessage,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
||||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
|
||||||
|
|
||||||
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
|
|
||||||
[]
|
|
||||||
)
|
|
||||||
for turn in conversation:
|
|
||||||
role = turn.get("role")
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
messages.append(UserMessage(content=turn["content"]))
|
|
||||||
elif role == "assistant":
|
|
||||||
messages.append(
|
|
||||||
AssistantMessage(
|
|
||||||
content=turn.get("content"),
|
|
||||||
tool_calls=turn.get("tool_calls"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif role == "tool":
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(
|
|
||||||
content=turn.get("content"),
|
|
||||||
tool_call_id=turn.get("tool_call_id"),
|
|
||||||
name=turn.get("name"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif role == "system":
|
|
||||||
messages.append(SystemMessage(content=turn["content"]))
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_calls: list[Tool] = []
|
|
||||||
if tools:
|
|
||||||
# convert to Tool
|
|
||||||
for tool in tools:
|
|
||||||
if tool["type"] != "function":
|
|
||||||
continue
|
|
||||||
|
|
||||||
function = tool["function"]
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
Tool(
|
|
||||||
function=Function(
|
|
||||||
name=function["name"],
|
|
||||||
description=function["description"],
|
|
||||||
# set parameters to empty dict if not provided
|
|
||||||
parameters=function.get("parameters", {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
|
|
||||||
messages=messages,
|
|
||||||
tools=tool_calls,
|
|
||||||
)
|
|
||||||
|
|
||||||
return chat_completion
|
|
||||||
|
|
||||||
def apply_chat_template(
|
def apply_chat_template(
|
||||||
self,
|
self,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
@@ -342,8 +271,8 @@ class HFMistralTokenizer:
|
|||||||
if add_generation_prompt:
|
if add_generation_prompt:
|
||||||
raise NotImplementedError("add_generation_prompt not supported yet")
|
raise NotImplementedError("add_generation_prompt not supported yet")
|
||||||
|
|
||||||
chat_completion: ChatCompletionRequest = (
|
chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai(
|
||||||
self._create_mistral_chat_completion_request(messages, tools)
|
messages, tools
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
|
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
|
||||||
@@ -408,13 +337,16 @@ class HFMistralTokenizer:
|
|||||||
padding_value=IGNORE_INDEX,
|
padding_value=IGNORE_INDEX,
|
||||||
)
|
)
|
||||||
|
|
||||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
attention_mask = None
|
||||||
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
if "attention_mask" in features[0]:
|
||||||
batch_first=True,
|
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||||
padding_value=0,
|
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
||||||
)
|
batch_first=True,
|
||||||
|
padding_value=0,
|
||||||
|
)
|
||||||
|
|
||||||
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
|
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
|
||||||
|
position_ids = None
|
||||||
if "position_ids" in features[0]:
|
if "position_ids" in features[0]:
|
||||||
if self.padding_side == "left":
|
if self.padding_side == "left":
|
||||||
# Likely not needed, but keeping for now
|
# Likely not needed, but keeping for now
|
||||||
@@ -443,22 +375,15 @@ class HFMistralTokenizer:
|
|||||||
pos_seq = torch.cat([pos_seq, pad_positions])
|
pos_seq = torch.cat([pos_seq, pad_positions])
|
||||||
position_ids_list.append(pos_seq)
|
position_ids_list.append(pos_seq)
|
||||||
position_ids = torch.stack(position_ids_list)
|
position_ids = torch.stack(position_ids_list)
|
||||||
else:
|
|
||||||
# Create position_ids if not present
|
|
||||||
seq_len = input_ids.size(1)
|
|
||||||
position_ids = (
|
|
||||||
torch.arange(seq_len, dtype=torch.long)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(input_ids.size(0), -1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure all tensors have the same sequence length
|
# Ensure all tensors have the same sequence length
|
||||||
max_seq_len = max(
|
# Check attention mask and position ids if they are present
|
||||||
input_ids.size(1),
|
tensor_lengths = [input_ids.size(1), labels.size(1)]
|
||||||
labels.size(1),
|
if attention_mask is not None:
|
||||||
attention_mask.size(1),
|
tensor_lengths.append(attention_mask.size(1))
|
||||||
position_ids.size(1),
|
if position_ids is not None:
|
||||||
)
|
tensor_lengths.append(position_ids.size(1))
|
||||||
|
max_seq_len = max(tensor_lengths)
|
||||||
|
|
||||||
# TODO: check if trimming is needed? and correct.
|
# TODO: check if trimming is needed? and correct.
|
||||||
|
|
||||||
@@ -492,44 +417,48 @@ class HFMistralTokenizer:
|
|||||||
elif labels.size(1) > max_seq_len:
|
elif labels.size(1) > max_seq_len:
|
||||||
labels = labels[:, :max_seq_len]
|
labels = labels[:, :max_seq_len]
|
||||||
|
|
||||||
if attention_mask.size(1) < max_seq_len:
|
if attention_mask is not None:
|
||||||
pad_len = max_seq_len - attention_mask.size(1)
|
if attention_mask.size(1) < max_seq_len:
|
||||||
if self.padding_side == "right":
|
pad_len = max_seq_len - attention_mask.size(1)
|
||||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
if self.padding_side == "right":
|
||||||
else:
|
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
||||||
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
else:
|
||||||
elif attention_mask.size(1) > max_seq_len:
|
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
||||||
attention_mask = attention_mask[:, :max_seq_len]
|
elif attention_mask.size(1) > max_seq_len:
|
||||||
|
attention_mask = attention_mask[:, :max_seq_len]
|
||||||
|
|
||||||
if position_ids.size(1) < max_seq_len:
|
if position_ids is not None:
|
||||||
pad_len = max_seq_len - position_ids.size(1)
|
if position_ids.size(1) < max_seq_len:
|
||||||
if self.padding_side == "right":
|
pad_len = max_seq_len - position_ids.size(1)
|
||||||
batch_size = position_ids.size(0)
|
if self.padding_side == "right":
|
||||||
new_position_ids = []
|
batch_size = position_ids.size(0)
|
||||||
for i in range(batch_size):
|
new_position_ids = []
|
||||||
seq = position_ids[i]
|
for i in range(batch_size):
|
||||||
if len(seq) > 0:
|
seq = position_ids[i]
|
||||||
# get last position and pad with sequential values
|
if len(seq) > 0:
|
||||||
last_pos = seq[-1].item()
|
# get last position and pad with sequential values
|
||||||
pad_positions = torch.arange(
|
last_pos = seq[-1].item()
|
||||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
pad_positions = torch.arange(
|
||||||
)
|
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||||
new_seq = torch.cat([seq, pad_positions])
|
)
|
||||||
else:
|
new_seq = torch.cat([seq, pad_positions])
|
||||||
new_seq = torch.arange(pad_len, dtype=torch.long)
|
else:
|
||||||
new_position_ids.append(new_seq)
|
new_seq = torch.arange(pad_len, dtype=torch.long)
|
||||||
position_ids = torch.stack(new_position_ids)
|
new_position_ids.append(new_seq)
|
||||||
else:
|
position_ids = torch.stack(new_position_ids)
|
||||||
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
else:
|
||||||
elif position_ids.size(1) > max_seq_len:
|
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
||||||
position_ids = position_ids[:, :max_seq_len]
|
elif position_ids.size(1) > max_seq_len:
|
||||||
|
position_ids = position_ids[:, :max_seq_len]
|
||||||
|
|
||||||
final_batch = {
|
final_batch = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"position_ids": position_ids,
|
|
||||||
}
|
}
|
||||||
|
if attention_mask is not None:
|
||||||
|
final_batch["attention_mask"] = attention_mask
|
||||||
|
if position_ids is not None:
|
||||||
|
final_batch["position_ids"] = position_ids
|
||||||
|
|
||||||
# Handle non-sequence fields (raise error)
|
# Handle non-sequence fields (raise error)
|
||||||
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
|
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
|
||||||
@@ -545,7 +474,7 @@ class HFMistralTokenizer:
|
|||||||
result = {}
|
result = {}
|
||||||
for k, v in final_batch.items():
|
for k, v in final_batch.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
result[k] = v.numpy().astype(np.long)
|
result[k] = v.numpy().astype(np.int64)
|
||||||
else:
|
else:
|
||||||
result[k] = v
|
result[k] = v
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -627,7 +627,7 @@ class AxolotlInputConfig(
|
|||||||
torch_compile: Literal["auto"] | bool | None = Field(
|
torch_compile: Literal["auto"] | bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.5.1"
|
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
torch_compile_backend: str | None = Field(
|
torch_compile_backend: str | None = Field(
|
||||||
@@ -1083,9 +1083,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
def check_min_torch_version(self):
|
def check_min_torch_version(self):
|
||||||
if self.env_capabilities and self.env_capabilities.torch_version:
|
if self.env_capabilities and self.env_capabilities.torch_version:
|
||||||
torch_version = self.env_capabilities.torch_version
|
torch_version = self.env_capabilities.torch_version
|
||||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
if version.parse(torch_version) < version.parse("2.6.0"):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
f"torch=={torch_version} not be supported. Please upgrade to torch>=2.6.0."
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -692,7 +692,7 @@ class TestValidation(BaseValidation):
|
|||||||
"bf16": True,
|
"bf16": True,
|
||||||
"capabilities": {"bf16": False},
|
"capabilities": {"bf16": False},
|
||||||
"env_capabilities": {
|
"env_capabilities": {
|
||||||
"torch_version": "2.5.1",
|
"torch_version": "2.6.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1202,7 +1202,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
)
|
)
|
||||||
|
|
||||||
env_capabilities = {"torch_version": "2.5.1"}
|
env_capabilities = {"torch_version": "2.6.0"}
|
||||||
capabilities = {"bf16": False}
|
capabilities = {"bf16": False}
|
||||||
_ = validate_config(
|
_ = validate_config(
|
||||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
@@ -1244,7 +1244,7 @@ class TestTorchCompileValidation(BaseValidation):
|
|||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
env_capabilities = {"torch_version": "2.5.1"}
|
env_capabilities = {"torch_version": "2.6.0"}
|
||||||
capabilities = {"bf16": True}
|
capabilities = {"bf16": True}
|
||||||
updated_cfg = validate_config(
|
updated_cfg = validate_config(
|
||||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
|||||||
@@ -164,6 +164,14 @@ def fixture_magistral_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="devstral_tokenizer")
|
||||||
|
def fixture_devstral_tokenizer():
|
||||||
|
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||||
|
|
||||||
|
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505")
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
|
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
|
||||||
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
|
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
|
||||||
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not 
none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
|
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not 
none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
|
||||||
|
|||||||
@@ -3,32 +3,50 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||||
|
|
||||||
|
|
||||||
def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
# fmt: off
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("tokenizer_str", "assistant_toolcall_ids"),
|
||||||
|
(
|
||||||
|
("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2)),
|
||||||
|
("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
def test_mistral_chat_template(
|
||||||
|
tokenizer_str: str,
|
||||||
|
assistant_toolcall_ids: tuple[int, ...],
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
):
|
||||||
|
"""Test chat template with the Magistral/Devstral tokenizer"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
|
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
|
||||||
|
|
||||||
# check bos, eos, pad, unk are accessible properties
|
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)
|
||||||
assert magistral_tokenizer.bos_token_id == 1
|
|
||||||
assert magistral_tokenizer.eos_token_id == 2
|
|
||||||
assert magistral_tokenizer.pad_token_id == 11
|
|
||||||
assert magistral_tokenizer.unk_token_id == 0
|
|
||||||
|
|
||||||
assert magistral_tokenizer.pad_token == "<pad>"
|
# check bos, eos, pad, unk are accessible properties
|
||||||
assert magistral_tokenizer.eos_token == "</s>"
|
assert tokenizer.bos_token_id == 1
|
||||||
assert magistral_tokenizer.bos_token == "<s>"
|
assert tokenizer.eos_token_id == 2
|
||||||
assert magistral_tokenizer.unk_token == "<unk>"
|
assert tokenizer.pad_token_id == 11
|
||||||
|
assert tokenizer.unk_token_id == 0
|
||||||
|
|
||||||
|
assert tokenizer.pad_token == "<pad>"
|
||||||
|
assert tokenizer.eos_token == "</s>"
|
||||||
|
assert tokenizer.bos_token == "<s>"
|
||||||
|
assert tokenizer.unk_token == "<unk>"
|
||||||
|
|
||||||
strategy = MistralStrategy(
|
strategy = MistralStrategy(
|
||||||
MistralPrompter(
|
MistralPrompter(
|
||||||
magistral_tokenizer,
|
tokenizer,
|
||||||
chat_template=None,
|
chat_template=None,
|
||||||
message_property_mappings={"role": "role", "content": "content"},
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
),
|
),
|
||||||
tokenizer=magistral_tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="turn",
|
train_on_eos="turn",
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -219,7 +237,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
1, # bos
|
1, # bos
|
||||||
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
|
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
|
||||||
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
|
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
|
||||||
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
|
*assistant_toolcall_ids, # assistant tool calling
|
||||||
7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result
|
7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result
|
||||||
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
||||||
2 # eos
|
2 # eos
|
||||||
@@ -229,7 +247,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
-100, # bos
|
-100, # bos
|
||||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
|
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
|
||||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
|
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
|
||||||
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
|
*assistant_toolcall_ids, # assistant tool calling
|
||||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result
|
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result
|
||||||
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
||||||
2 # eos
|
2 # eos
|
||||||
@@ -237,7 +255,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# test chat template with tokenize=False
|
# test chat template with tokenize=False
|
||||||
res = magistral_tokenizer.apply_chat_template(
|
res = tokenizer.apply_chat_template(
|
||||||
[
|
[
|
||||||
{"role": "user", "content": "Hello, how are you?"},
|
{"role": "user", "content": "Hello, how are you?"},
|
||||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||||
@@ -248,7 +266,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
|
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
|
||||||
|
|
||||||
# test encode
|
# test encode
|
||||||
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True)
|
res = tokenizer.encode("Hello, how are you?", add_special_tokens=True)
|
||||||
assert res == [
|
assert res == [
|
||||||
1, # bos
|
1, # bos
|
||||||
22177, # Hello
|
22177, # Hello
|
||||||
@@ -261,16 +279,16 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# test decode no skip special tokens
|
# test decode no skip special tokens
|
||||||
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False)
|
decoded_res = tokenizer.decode(res, skip_special_tokens=False)
|
||||||
|
|
||||||
assert decoded_res == "<s>Hello, how are you?</s>"
|
assert decoded_res == "<s>Hello, how are you?</s>"
|
||||||
|
|
||||||
# test decode skip special tokens
|
# test decode skip special tokens
|
||||||
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True)
|
decoded_res = tokenizer.decode(res, skip_special_tokens=True)
|
||||||
assert decoded_res == "Hello, how are you?"
|
assert decoded_res == "Hello, how are you?"
|
||||||
|
|
||||||
# test encode no special tokens
|
# test encode no special tokens
|
||||||
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False)
|
res = tokenizer.encode("Hello, how are you?", add_special_tokens=False)
|
||||||
assert res == [
|
assert res == [
|
||||||
22177, # Hello
|
22177, # Hello
|
||||||
1044, # ,
|
1044, # ,
|
||||||
@@ -281,10 +299,452 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# test convert ids to tokens
|
# test convert ids to tokens
|
||||||
res = magistral_tokenizer.convert_ids_to_tokens(res)
|
res = tokenizer.convert_ids_to_tokens(res)
|
||||||
# spacing are needed as we are converting without decoding
|
# spacing are needed as we are converting without decoding
|
||||||
assert res == ["Hello", ",", " how", " are", " you", "?"]
|
assert res == ["Hello", ",", " how", " are", " you", "?"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"):
    """Exercise ``pad`` on the Magistral tokenizer wrapper.

    Covers: padding of ``input_ids``/``labels`` only, optional
    ``attention_mask`` and ``position_ids`` fields, batches that need no
    padding, ``padding="max_length"``, numpy output, and rejection of
    unsupported fields.
    """
    from axolotl.utils.collators.core import IGNORE_INDEX

    magistral_pad_token_id = 11  # taken from tokenizer.pad_token_id

    def pad_pt(batch):
        # Every torch-returning call in this test uses the same keyword args.
        return magistral_tokenizer.pad(batch, padding=True, return_tensors="pt")

    # --- input_ids and labels only -------------------------------------
    features = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        {"input_ids": [7, 8], "labels": [9, 10]},
    ]
    padded = pad_pt(features)

    # input_ids are filled with the pad token id
    assert padded["input_ids"].shape == (2, 3)
    expected_input_ids = [[1, 2, 3], [7, 8, magistral_pad_token_id]]
    assert padded["input_ids"].tolist() == expected_input_ids

    # labels are filled with IGNORE_INDEX
    assert padded["labels"].shape == (2, 3)
    expected_labels = [[4, 5, 6], [9, 10, IGNORE_INDEX]]
    assert padded["labels"].tolist() == expected_labels

    # fields absent from the input must not appear in the output
    assert "attention_mask" not in padded
    assert "position_ids" not in padded

    # --- with attention_mask -------------------------------------------
    features_with_attention = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]},
        {"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]},
    ]
    padded = pad_pt(features_with_attention)

    # padded positions are masked out with 0
    assert padded["attention_mask"].shape == (2, 3)
    assert padded["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]]

    # --- with position_ids ---------------------------------------------
    features_with_position = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]},
        {"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]},
    ]
    padded = pad_pt(features_with_position)

    # padded positions continue the sequence
    assert padded["position_ids"].shape == (2, 3)
    assert padded["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]]

    # --- all fields together -------------------------------------------
    features_all = [
        {
            "input_ids": [1, 2, 3],
            "labels": [4, 5, 6],
            "attention_mask": [1, 1, 1],
            "position_ids": [0, 1, 2],
        },
        {
            "input_ids": [7, 8],
            "labels": [9, 10],
            "attention_mask": [1, 1],
            "position_ids": [0, 1],
        },
    ]
    padded = pad_pt(features_all)

    # every supplied field must be present in the output
    for field in ("input_ids", "labels", "attention_mask", "position_ids"):
        assert field in padded

    # --- all sequences already the same length -------------------------
    features_same_length = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        {"input_ids": [7, 8, 9], "labels": [10, 11, 12]},
    ]
    padded = pad_pt(features_same_length)

    # rows must come back unchanged when no padding is needed
    for row, feature in enumerate(features_same_length):
        assert padded["input_ids"][row].tolist() == feature["input_ids"]
        assert padded["labels"][row].tolist() == feature["labels"]

    # --- explicit max_length -------------------------------------------
    padded = magistral_tokenizer.pad(
        features, padding="max_length", max_length=5, return_tensors="pt"
    )

    # both fields are padded out to max_length
    assert padded["input_ids"].shape == (2, 5)
    assert padded["labels"].shape == (2, 5)

    # --- numpy return type ---------------------------------------------
    padded = magistral_tokenizer.pad(features, padding=True, return_tensors="np")

    import numpy as np

    for field in ("input_ids", "labels"):
        assert isinstance(padded[field], np.ndarray)

    # --- unsupported field rejection -----------------------------------
    features_unsupported = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]},
    ]
    with pytest.raises(NotImplementedError, match="unsupported_field"):
        pad_pt(features_unsupported)

    # --- token_type_ids rejection --------------------------------------
    features_token_type = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]},
    ]
    with pytest.raises(ValueError, match="token_type_ids is not supported"):
        pad_pt(features_token_type)
|
||||||
|
|
||||||
|
|
||||||
|
def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
    """Test tool calling with the Magistral tokenizer.

    Tokenizes several tool-calling conversations through ``MistralStrategy``
    and checks the decoded token stream for the expected control-token
    layout:

    * a single tool call followed by its result,
    * two sequential tool calls,
    * a tool call preceded by a system message,
    * an assistant tool call with no tool response, which must raise
      ``InvalidMessageStructureException``.
    """
    from mistral_common.exceptions import InvalidMessageStructureException

    from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy

    strategy = MistralStrategy(
        MistralPrompter(
            magistral_tokenizer,
            chat_template=None,
            message_property_mappings={"role": "role", "content": "content"},
        ),
        tokenizer=magistral_tokenizer,
        train_on_inputs=False,
        train_on_eos="turn",
        sequence_len=512,
        roles_to_train=["assistant"],
    )

    # Test basic tool calling with single function
    basic_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                        },
                        "required": ["location"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "What's the weather like in San Francisco?",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call12345",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                            },
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call12345",
                "name": "get_weather",
                "content": "Sunny, 72°F",
            },
            {
                "role": "assistant",
                "content": "The weather in San Francisco is sunny and 72°F.",
            },
        ],
    }

    res = strategy.tokenize_prompt(basic_tool_calling)

    # Basic validation
    assert "input_ids" in res
    assert "labels" in res
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])

    # Decode and verify structure
    decoded = magistral_tokenizer.decode(res["input_ids"])
    assert (
        '<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )
    assert (
        '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}</s>'
        in decoded
    )
    assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded
    assert "The weather in San Francisco is sunny and 72°F.</s>" in decoded

    # Test multiple tool calls in sequence
    multi_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "add_numbers",
                    "description": "Add two numbers together",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "number", "description": "First number"},
                            "b": {"type": "number", "description": "Second number"},
                        },
                        "required": ["a", "b"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "multiply_numbers",
                    "description": "Multiply two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {"type": "number", "description": "First number"},
                            "y": {"type": "number", "description": "Second number"},
                        },
                        "required": ["x", "y"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "Add 5 and 3, then multiply the result by 2",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call12345",
                        "type": "function",
                        "function": {
                            "name": "add_numbers",
                            "arguments": {"a": 5, "b": 3},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call12345",
                "name": "add_numbers",
                "content": "8",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call23456",
                        "type": "function",
                        "function": {
                            "name": "multiply_numbers",
                            "arguments": {"x": 8, "y": 2},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call23456",
                "name": "multiply_numbers",
                "content": "16",
            },
            {
                "role": "assistant",
                "content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.",
            },
        ],
    }

    res = strategy.tokenize_prompt(multi_tool_calling)

    # Validation
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])

    decoded = magistral_tokenizer.decode(res["input_ids"])
    assert (
        '<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )
    assert (
        '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}</s>' in decoded
    )
    assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded
    assert (
        '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}</s>'
        in decoded
    )
    assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded
    assert (
        "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.</s>"
        in decoded
    )

    # Test tool calling with system message
    system_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "search_database",
                    "description": "Search for information in database",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "Search query"},
                        },
                        "required": ["query"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant with access to a database.",
            },
            {
                "role": "user",
                "content": "Find information about Python programming",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "search123",
                        "type": "function",
                        "function": {
                            "name": "search_database",
                            "arguments": {"query": "Python programming"},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "search123",
                "name": "search_database",
                "content": "Python is a high-level programming language known for its simplicity.",
            },
            {
                "role": "assistant",
                "content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.",
            },
        ],
    }

    res = strategy.tokenize_prompt(system_tool_calling)

    # Validation
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])

    decoded = magistral_tokenizer.decode(res["input_ids"])

    assert (
        '<s>[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )

    # Test error handling - missing tool response
    incomplete_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get current time",
                    "parameters": {"type": "object", "properties": {}},
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "What time is it?",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "time12345",
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": {},
                        },
                    }
                ],
            },
            {
                "role": "assistant",
                "content": "The current time is 12:00 PM.",
            },
        ],
    }

    # pytest.raises fails the test when nothing is raised; a bare
    # try/except would let a silently-accepted malformed conversation pass.
    with pytest.raises(
        InvalidMessageStructureException,
        match="Not the same number of function calls and responses",
    ):
        strategy.tokenize_prompt(incomplete_tool_calling)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
env_capabilities={
|
env_capabilities={
|
||||||
"torch_version": "2.5.1",
|
"torch_version": "2.6.0",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -128,7 +128,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
env_capabilities={
|
env_capabilities={
|
||||||
"torch_version": "2.5.1",
|
"torch_version": "2.6.0",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -184,7 +184,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
env_capabilities={
|
env_capabilities={
|
||||||
"torch_version": "2.5.1",
|
"torch_version": "2.6.0",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -241,7 +241,7 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
env_capabilities={
|
env_capabilities={
|
||||||
"torch_version": "2.5.1",
|
"torch_version": "2.6.0",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user