Compare commits
8 Commits
moekernels
...
3181
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
939023e661 | ||
|
|
6bc959342b | ||
|
|
b3b92687c4 | ||
|
|
55d1be2ae6 | ||
|
|
08d831c3d5 | ||
|
|
7be8740c5c | ||
|
|
c51d6b06c3 | ||
|
|
09959fac70 |
@@ -13,6 +13,7 @@ format:
|
|||||||
- [Pixtral](#sec-pixtral)
|
- [Pixtral](#sec-pixtral)
|
||||||
- [Llava-1.5](#sec-llava-15)
|
- [Llava-1.5](#sec-llava-15)
|
||||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||||
|
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||||
- [Voxtral](#sec-voxtral)
|
- [Voxtral](#sec-voxtral)
|
||||||
- [Gemma-3](#sec-gemma-3)
|
- [Gemma-3](#sec-gemma-3)
|
||||||
- [Gemma-3n](#sec-gemma-3n)
|
- [Gemma-3n](#sec-gemma-3n)
|
||||||
@@ -41,7 +42,6 @@ datasets:
|
|||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
|
||||||
|
|
||||||
# (optional) if doing lora, only finetune the Language model,
|
# (optional) if doing lora, only finetune the Language model,
|
||||||
# leave the vision model and vision tower frozen
|
# leave the vision model and vision tower frozen
|
||||||
@@ -94,10 +94,22 @@ chat_template: llava
|
|||||||
|
|
||||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||||
|
:::
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
```
|
||||||
|
|
||||||
chat_template: mistral_v7_tekken
|
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||||
|
:::
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Magistral-Small-2509
|
||||||
```
|
```
|
||||||
|
|
||||||
### Voxtral {#sec-voxtral}
|
### Voxtral {#sec-voxtral}
|
||||||
|
|||||||
110
examples/apertus/README.md
Normal file
110
examples/apertus/README.md
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
# Finetune Swiss-AI's Apertus with Axolotl
|
||||||
|
|
||||||
|
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. (Optional, highly recommended) Install XIELU CUDA
|
||||||
|
|
||||||
|
```bash
|
||||||
|
## Recommended for reduced VRAM and faster speeds
|
||||||
|
|
||||||
|
# Point to CUDA toolkit directory
|
||||||
|
# For those using our Docker image, use the below path.
|
||||||
|
export CUDA_HOME=/usr/local/cuda
|
||||||
|
|
||||||
|
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||||
|
```
|
||||||
|
|
||||||
|
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/apertus/apertus-8b-qlora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 8.7 GiB VRAM.
|
||||||
|
|
||||||
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### Tips
|
||||||
|
|
||||||
|
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
|
||||||
|
- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
|
### XIELU Installation Issues
|
||||||
|
|
||||||
|
#### `ModuleNotFoundError: No module named 'torch'`
|
||||||
|
|
||||||
|
Please check these one by one:
|
||||||
|
- Running in correct environment
|
||||||
|
- Env has PyTorch installed
|
||||||
|
- CUDA toolkit is at `CUDA_HOME`
|
||||||
|
|
||||||
|
If those didn't help, please try the below solutions:
|
||||||
|
|
||||||
|
1. Pass env for CMAKE and try install again:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Git clone the repo and manually hardcode python path:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/nickjbrowning/XIELU
|
||||||
|
cd xielu
|
||||||
|
git checkout 59d6031
|
||||||
|
|
||||||
|
cd xielu
|
||||||
|
nano CMakeLists.txt # or vi depending on your preference
|
||||||
|
```
|
||||||
|
|
||||||
|
```diff
|
||||||
|
execute_process(
|
||||||
|
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||||
|
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||||
|
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
|
||||||
|
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
|
||||||
|
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip3 install . --no-build-isolation --no-deps
|
||||||
|
```
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||||
|
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||||
|
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: swiss-ai/Apertus-8B-Instruct-2509
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -19,6 +19,9 @@ cd axolotl
|
|||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run the finetuning example:
|
2. Run the finetuning example:
|
||||||
|
|||||||
@@ -9,10 +9,6 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -9,10 +9,6 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
|
|||||||
@@ -9,10 +9,6 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ datasets:
|
|||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|||||||
@@ -23,7 +23,15 @@ pip3 install timm==1.0.17
|
|||||||
pip3 install librosa==0.11.0
|
pip3 install librosa==0.11.0
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Run the finetuning example:
|
3. Download sample dataset files
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# for text + vision + audio only
|
||||||
|
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
|
||||||
|
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run the finetuning example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# text only
|
# text only
|
||||||
|
|||||||
@@ -12,15 +12,6 @@ chat_template: llama3
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
roles:
|
|
||||||
user:
|
|
||||||
- user
|
|
||||||
assistant:
|
|
||||||
- assistant
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ datasets:
|
|||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ datasets:
|
|||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# Finetune Magistral Small with Axolotl
|
# Finetune Magistral Small with Axolotl
|
||||||
|
|
||||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||||
|
|
||||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||||
|
|
||||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
@@ -36,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
### Thinking
|
### Thinking
|
||||||
|
|
||||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||||
|
|
||||||
Example format:
|
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||||
|
|
||||||
```json
|
### Vision
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
|
||||||
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
|
|
||||||
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Example config: `./magistral-small-think-qlora.yaml`.
|
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
|
||||||
|
|
||||||
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
|
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||||
|
|
||||||
Limitations:
|
### Tips
|
||||||
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
|
|
||||||
- This mode does not work with custom `train_detail` and `training` at the moment.
|
|
||||||
|
|
||||||
### TIPS
|
|
||||||
|
|
||||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||||
@@ -89,5 +77,5 @@ In addition, we do not support overriding tokens yet.
|
|||||||
|
|
||||||
## Future Work
|
## Future Work
|
||||||
|
|
||||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
- Add parity to Preference Tuning, RL, etc.
|
||||||
- Add parity to other tokenizer configs like overriding tokens.
|
- Add parity to other tokenizer configs like overriding tokens.
|
||||||
|
|||||||
73
examples/magistral/think/README.md
Normal file
73
examples/magistral/think/README.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Magistral Small Thinking Fine-tuning
|
||||||
|
|
||||||
|
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Before starting, ensure you have:
|
||||||
|
- Installed Axolotl (see [main README](../README.md))
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
Run the thinking model fine-tuning:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train magistral-small-think-qlora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 19.1 GiB VRAM.
|
||||||
|
|
||||||
|
### Tips
|
||||||
|
|
||||||
|
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||||
|
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||||
|
|
||||||
|
Example format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Advanced Options
|
||||||
|
|
||||||
|
The `thinking` section supports an optional `closed` parameter:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": "Internal reasoning here...",
|
||||||
|
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||||
|
}
|
||||||
|
```
|
||||||
60
examples/magistral/vision/README.md
Normal file
60
examples/magistral/vision/README.md
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# Magistral Small Vision Fine-tuning
|
||||||
|
|
||||||
|
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Before starting, ensure you have:
|
||||||
|
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install the required vision lib:
|
||||||
|
```bash
|
||||||
|
pip install 'mistral-common[opencv]==1.8.5'
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Download the example dataset image:
|
||||||
|
```bash
|
||||||
|
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the fine-tuning:
|
||||||
|
```bash
|
||||||
|
axolotl train magistral-small-vision-24B-qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 17GiB VRAM.
|
||||||
|
|
||||||
|
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||||
|
|
||||||
|
### Tips
|
||||||
|
|
||||||
|
Key differences from text-only model:
|
||||||
|
- `max_tokens: 131072` for inference
|
||||||
|
- Multi-modal dataset format required
|
||||||
|
- Sample packing not supported
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||||
|
|
||||||
|
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{ "type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image", "path": "path/to/image.jpg" }
|
||||||
|
]},
|
||||||
|
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- Sample Packing is not supported for multi-modality training currently.
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: mistralai/Magistral-Small-2509
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
# Enable to use mistral-common tokenizer
|
||||||
|
tokenizer_use_mistral_common: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
# sample dataset below requires downloading image in advance
|
||||||
|
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||||
|
datasets:
|
||||||
|
- path: Nanobit/text-vision-2k-test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
# Enable to use mistral-common tokenizer
|
||||||
|
tokenizer_use_mistral_common: true
|
||||||
|
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
@@ -8,12 +11,12 @@ skip_prepare_dataset: true
|
|||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|
||||||
chat_template: mistral_v7_tekken
|
# sample dataset below requires downloading image in advance
|
||||||
|
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: Nanobit/text-vision-2k-test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
@@ -48,8 +51,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
flash_attention: true
|
||||||
sdp_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
@@ -12,15 +12,6 @@ chat_template: phi_3
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
roles:
|
|
||||||
user:
|
|
||||||
- user
|
|
||||||
assistant:
|
|
||||||
- assistant
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
|
|||||||
@@ -45,8 +45,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
|
flash_attention: true
|
||||||
sdp_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|||||||
64
examples/qwen3-next/README.md
Normal file
64
examples/qwen3-next/README.md
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# Finetune Qwen3-Next with Axolotl
|
||||||
|
|
||||||
|
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install Qwen3-Next transformers commit
|
||||||
|
```bash
|
||||||
|
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Install FLA for improved performance
|
||||||
|
```bash
|
||||||
|
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 41.7 GiB VRAM.
|
||||||
|
|
||||||
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||||
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||||
|
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||||
|
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
60
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
60
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 8
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -27,7 +27,14 @@ pip3 install 'mistral_common[audio]==1.8.3'
|
|||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Run the finetuning example:
|
3. Download sample dataset files
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# for text + audio only
|
||||||
|
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run the finetuning example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# text only
|
# text only
|
||||||
|
|||||||
@@ -70,4 +70,4 @@ schedulefree==1.4.1
|
|||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.5
|
axolotl-contribs-mit==0.0.5
|
||||||
|
|
||||||
mistral-common==1.8.3
|
mistral-common==1.8.5
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
|
||||||
)
|
)
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -124,7 +124,6 @@ extras_require = {
|
|||||||
"ring-flash-attn": [
|
"ring-flash-attn": [
|
||||||
"flash-attn==2.8.3",
|
"flash-attn==2.8.3",
|
||||||
"ring-flash-attn>=0.1.7",
|
"ring-flash-attn>=0.1.7",
|
||||||
"yunchang==0.6.0",
|
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.17.5",
|
"deepspeed==0.17.5",
|
||||||
|
|||||||
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
|
if self.cfg.max_prompt_len:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
else:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl is RLType.SIMPO:
|
if self.cfg.rl is RLType.SIMPO:
|
||||||
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.cpo_alpha is not None:
|
if self.cfg.cpo_alpha is not None:
|
||||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||||
|
|
||||||
|
# Handle when max_prompt_length == max_length from defaults
|
||||||
|
# CPOTrainer requires strictly less than
|
||||||
|
if (
|
||||||
|
training_args_kwargs["max_prompt_length"]
|
||||||
|
== training_args_kwargs["max_length"]
|
||||||
|
):
|
||||||
|
training_args_kwargs["max_prompt_length"] -= 1
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.ORPO:
|
elif self.cfg.rl is RLType.ORPO:
|
||||||
training_args_cls = AxolotlORPOConfig
|
training_args_cls = AxolotlORPOConfig
|
||||||
if self.cfg.max_prompt_len:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.KTO:
|
elif self.cfg.rl is RLType.KTO:
|
||||||
training_args_cls = AxolotlKTOConfig
|
training_args_cls = AxolotlKTOConfig
|
||||||
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kto_undesirable_weight or 1.0
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.max_prompt_len:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.GRPO:
|
elif self.cfg.rl is RLType.GRPO:
|
||||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, Mapping
|
|||||||
|
|
||||||
def chat_message_transform_builder(
|
def chat_message_transform_builder(
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
conversations_field: str = "conversations",
|
conversations_field: str = "messages",
|
||||||
message_field_role: str | list[str] | None = None, # commonly "role"
|
message_field_role: str | list[str] | None = None, # commonly "role"
|
||||||
message_field_content: str | list[str] | None = None, # commonly "content"
|
message_field_content: str | list[str] | None = None, # commonly "content"
|
||||||
message_field_training: str | list[str] | None = None, # commonly "weight"
|
message_field_training: str | list[str] | None = None, # commonly "weight"
|
||||||
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
|
|||||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
conversations_field (str, optional):
|
conversations_field (str, optional):
|
||||||
The field name of the conversations. Defaults to "conversations".
|
The field name of the conversations. Defaults to "messages".
|
||||||
message_field_role (str | list[str], optional):
|
message_field_role (str | list[str], optional):
|
||||||
The field name of the role. Defaults to "role".
|
The field name of the role.
|
||||||
message_field_content (str | list[str], optional):
|
message_field_content (str | list[str], optional):
|
||||||
The field name of the message content. Defaults to "content".
|
The field name of the message content.
|
||||||
message_field_training (str | list[str], optional):
|
message_field_training (str | list[str], optional):
|
||||||
The field name of the train/weight. Defaults to "weight".
|
The field name of the train/weight.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Callable:
|
Callable:
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
@@ -37,4 +36,6 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||||
if cfg.dpo_use_logits_to_keep is not None:
|
if cfg.dpo_use_logits_to_keep is not None:
|
||||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||||
|
if cfg.dpo_disable_output_fp32 is not None:
|
||||||
|
training_args_kwargs["disable_output_fp32"] = cfg.dpo_disable_output_fp32
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -16,3 +16,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dpo_norm_loss: bool | None = False
|
dpo_norm_loss: bool | None = False
|
||||||
|
disable_output_fp32: bool | None = False
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -65,6 +65,7 @@ plugins:
|
|||||||
- qwen2_5_vl
|
- qwen2_5_vl
|
||||||
- qwen3
|
- qwen3
|
||||||
- qwen3_moe
|
- qwen3_moe
|
||||||
|
- qwen3_next
|
||||||
- smollm3
|
- smollm3
|
||||||
- seed_oss
|
- seed_oss
|
||||||
- voxtral
|
- voxtral
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class PatchManager:
|
|||||||
# self._apply_flex_attention_patches()
|
# self._apply_flex_attention_patches()
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
self._apply_chunked_cross_entropy_patch()
|
self._apply_chunked_cross_entropy_patch()
|
||||||
|
self._apply_dpo_disable_output_fp32_patch()
|
||||||
self._apply_fsdp_patches()
|
self._apply_fsdp_patches()
|
||||||
self._apply_adapter_patches()
|
self._apply_adapter_patches()
|
||||||
self._apply_model_specific_patches()
|
self._apply_model_specific_patches()
|
||||||
@@ -68,11 +69,12 @@ class PatchManager:
|
|||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_fsdp2_bnb_patches()
|
self._apply_fsdp2_bnb_patches()
|
||||||
self._apply_patch_deepspeed_zero3()
|
self._apply_patch_deepspeed_zero3()
|
||||||
|
self._apply_voxtral_patches()
|
||||||
|
self._apply_apertus_patches()
|
||||||
|
|
||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
self._apply_voxtral_patches()
|
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
@@ -106,6 +108,16 @@ class PatchManager:
|
|||||||
else:
|
else:
|
||||||
patch_chunked_ce_loss_fn()
|
patch_chunked_ce_loss_fn()
|
||||||
|
|
||||||
|
def _apply_dpo_disable_output_fp32_patch(self):
|
||||||
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
|
if self.cfg.rl in {RLType.DPO, RLType.IPO} and self.cfg.dpo_disable_output_fp32:
|
||||||
|
from axolotl.monkeypatch.trainer.dpo_chunked import (
|
||||||
|
patch_dpo_disable_output_fp32,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_dpo_disable_output_fp32()
|
||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
if self.cfg.context_parallel_size > 1 or (
|
if self.cfg.context_parallel_size > 1 or (
|
||||||
@@ -168,6 +180,20 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_llama4_linearized_modeling()
|
patch_llama4_linearized_modeling()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||||
|
patch_qwen3_next_modeling_packing,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_qwen3_next_modeling_packing()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
|
||||||
|
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
|
||||||
|
apply_mistral_tokenizer_image_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_mistral_tokenizer_image_patch()
|
||||||
|
|
||||||
def _apply_fp8_patches(self):
|
def _apply_fp8_patches(self):
|
||||||
"""Apply patches for FP8 support."""
|
"""Apply patches for FP8 support."""
|
||||||
if self.cfg.fp8:
|
if self.cfg.fp8:
|
||||||
@@ -334,6 +360,13 @@ class PatchManager:
|
|||||||
|
|
||||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||||
|
|
||||||
|
if self.model_config.model_type in ("mistral3", "llava"):
|
||||||
|
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
|
||||||
|
apply_patch_is_packed_sequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_patch_is_packed_sequence()
|
||||||
|
|
||||||
def _patch_loss_llama(self):
|
def _patch_loss_llama(self):
|
||||||
"""Patch loss functions and other optimizations for LLaMA models."""
|
"""Patch loss functions and other optimizations for LLaMA models."""
|
||||||
if not self.cfg.is_llama_derived_model:
|
if not self.cfg.is_llama_derived_model:
|
||||||
@@ -479,3 +512,12 @@ class PatchManager:
|
|||||||
apply_deepspeed_patches()
|
apply_deepspeed_patches()
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
LOG.warning(f"DeepSpeed patches not applied: {e}")
|
LOG.warning(f"DeepSpeed patches not applied: {e}")
|
||||||
|
|
||||||
|
def _apply_apertus_patches(self):
|
||||||
|
"""Apply patches for Apertus model."""
|
||||||
|
if self.cfg.model_config_type == "apertus":
|
||||||
|
from axolotl.monkeypatch.models.apertus.activation import (
|
||||||
|
patch_apertus_xielu_activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_apertus_xielu_activation()
|
||||||
|
|||||||
@@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
if cfg.processor_type:
|
if cfg.processor_type:
|
||||||
processor_cls = getattr(transformers, cfg.processor_type)
|
processor_cls = getattr(transformers, cfg.processor_type)
|
||||||
|
|
||||||
|
if cfg.tokenizer_use_mistral_common:
|
||||||
|
from axolotl.utils.mistral import Mistral3Processor
|
||||||
|
|
||||||
|
return Mistral3Processor(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
processor = processor_cls.from_pretrained(
|
processor = processor_cls.from_pretrained(
|
||||||
cfg.processor_config,
|
cfg.processor_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
|||||||
@@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
|
|
||||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||||
"""Load mistral-common tokenizer"""
|
"""Load mistral-common tokenizer"""
|
||||||
from transformers import tokenization_mistral_common
|
|
||||||
|
|
||||||
from axolotl.utils.mistral import HFMistralTokenizer
|
from axolotl.utils.mistral import HFMistralTokenizer
|
||||||
|
|
||||||
# patch
|
|
||||||
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
|
|
||||||
|
|
||||||
# Load the HF-compatible wrapper around MistralTokenizer
|
# Load the HF-compatible wrapper around MistralTokenizer
|
||||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||||
|
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/models/apertus/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/apertus/__init__.py
Normal file
52
src/axolotl/monkeypatch/models/apertus/activation.py
Normal file
52
src/axolotl/monkeypatch/models/apertus/activation.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Monkeypatch for Apertus to dtype mismatch in XIELU act"""
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def patch_apertus_xielu_activation():
|
||||||
|
try:
|
||||||
|
from transformers.activations import XIELUActivation
|
||||||
|
except ImportError as err:
|
||||||
|
raise ImportError(
|
||||||
|
"Cannot import XIELUActivation. "
|
||||||
|
"Please make sure to update your transformers version >= 4.56.1."
|
||||||
|
) from err
|
||||||
|
|
||||||
|
from transformers.activations import logger
|
||||||
|
|
||||||
|
# Store the original method
|
||||||
|
old_fn = XIELUActivation._xielu_cuda
|
||||||
|
|
||||||
|
def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:
|
||||||
|
"""Firewall function to prevent torch.compile from seeing .item() calls"""
|
||||||
|
original_shape = x.shape
|
||||||
|
# CUDA kernel expects 3D tensors, reshape if needed
|
||||||
|
while x.dim() < 3:
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
if x.dim() > 3:
|
||||||
|
x = x.view(-1, 1, x.size(-1))
|
||||||
|
if original_shape != x.shape:
|
||||||
|
logger.warning_once(
|
||||||
|
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
|
||||||
|
original_shape,
|
||||||
|
x.shape,
|
||||||
|
)
|
||||||
|
result = self._xielu_cuda_obj.forward(
|
||||||
|
x,
|
||||||
|
self.alpha_p.to(x.dtype),
|
||||||
|
self.alpha_n.to(x.dtype),
|
||||||
|
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
|
||||||
|
self._beta_scalar,
|
||||||
|
self._eps_scalar,
|
||||||
|
self.with_vector_loads,
|
||||||
|
)
|
||||||
|
return result.view(original_shape)
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
XIELUActivation._xielu_cuda = _xielu_cuda_fixed
|
||||||
|
|
||||||
|
def unpatch():
|
||||||
|
"""Restore the original method"""
|
||||||
|
XIELUActivation._xielu_cuda = old_fn
|
||||||
|
|
||||||
|
return unpatch
|
||||||
0
src/axolotl/monkeypatch/models/mistral3/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/mistral3/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_mistral_tokenizer_image_patch():
|
||||||
|
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
|
||||||
|
from transformers.tokenization_mistral_common import MistralCommonTokenizer
|
||||||
|
|
||||||
|
# Get original source
|
||||||
|
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
|
||||||
|
original_source, _ = detab_code(original_source)
|
||||||
|
|
||||||
|
# Define the replacement
|
||||||
|
original_tensor_conversion = (
|
||||||
|
" pixel_values = torch.tensor(images)"
|
||||||
|
)
|
||||||
|
|
||||||
|
patched_tensor_conversion = """ if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray):
|
||||||
|
pixel_values = torch.tensor(np.array(images))
|
||||||
|
else:
|
||||||
|
pixel_values = torch.tensor(images)"""
|
||||||
|
|
||||||
|
# Apply the replacement
|
||||||
|
if original_tensor_conversion in original_source:
|
||||||
|
patched_source = original_source.replace(
|
||||||
|
original_tensor_conversion, patched_tensor_conversion
|
||||||
|
)
|
||||||
|
patched_source = patched_source.replace(
|
||||||
|
"def apply_chat_template(",
|
||||||
|
"def patched_apply_chat_template(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load necessary imports from the module
|
||||||
|
module_name = MistralCommonTokenizer.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# Detect what needs to be imported
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(module):
|
||||||
|
if item in patched_source and not item.startswith("_"):
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
# Execute imports in global scope
|
||||||
|
if items_to_import:
|
||||||
|
exec( # nosec B102
|
||||||
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also need standard imports that might be used
|
||||||
|
exec("import numpy as np", globals()) # nosec B102
|
||||||
|
exec("import torch", globals()) # nosec B102
|
||||||
|
exec("from typing import Union, Optional, List, Dict, Any, Callable", globals()) # nosec B102
|
||||||
|
exec("from pathlib import Path", globals()) # nosec B102
|
||||||
|
|
||||||
|
# Import other dependencies that might be needed
|
||||||
|
try:
|
||||||
|
exec("from transformers.utils import is_torch_available", globals()) # nosec B102
|
||||||
|
exec(
|
||||||
|
"from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType",
|
||||||
|
globals(),
|
||||||
|
) # nosec B102
|
||||||
|
exec("from transformers.utils import logging", globals()) # nosec B102
|
||||||
|
exec("logger = logging.get_logger(__name__)", globals()) # nosec B102
|
||||||
|
except ImportError as e:
|
||||||
|
LOG.warning(f"Could not import some dependencies: {e}")
|
||||||
|
|
||||||
|
# Execute the patched source
|
||||||
|
exec(patched_source, globals()) # nosec B102
|
||||||
|
|
||||||
|
# Replace the method
|
||||||
|
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
|
||||||
|
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
|
||||||
|
else:
|
||||||
|
LOG.warning("Could not find target code for MistralCommonTokenizer patching")
|
||||||
0
src/axolotl/monkeypatch/models/pixtral/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/pixtral/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def apply_patch_is_packed_sequence():
|
||||||
|
"""Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
|
||||||
|
from transformers import modeling_flash_attention_utils
|
||||||
|
|
||||||
|
def fixed_is_packed_sequence(position_ids, batch_size):
|
||||||
|
"""
|
||||||
|
Check the position ids whether packed sequences are indicated or not
|
||||||
|
1. Position ids exist
|
||||||
|
2. Flattened sequences only are supported
|
||||||
|
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
|
||||||
|
"""
|
||||||
|
if position_ids is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if position_ids.ndim == 1:
|
||||||
|
position_ids = position_ids.unsqueeze(0) # [N] -> [1, N]
|
||||||
|
|
||||||
|
increasing_position_sequences = (
|
||||||
|
torch.arange(position_ids.shape[1], device=position_ids.device)
|
||||||
|
+ position_ids.min()
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
batch_size == 1
|
||||||
|
and (increasing_position_sequences - position_ids).abs().sum().bool().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method
|
||||||
|
old_fn = modeling_flash_attention_utils._is_packed_sequence
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence
|
||||||
|
|
||||||
|
def unpatch():
|
||||||
|
"""Restore the original method"""
|
||||||
|
modeling_flash_attention_utils._is_packed_sequence = old_fn
|
||||||
|
|
||||||
|
return unpatch
|
||||||
1
src/axolotl/monkeypatch/models/qwen3_next/__init__.py
Normal file
1
src/axolotl/monkeypatch/models/qwen3_next/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Qwen3_Next model monkeypatches."""
|
||||||
317
src/axolotl/monkeypatch/models/qwen3_next/modeling.py
Normal file
317
src/axolotl/monkeypatch/models/qwen3_next/modeling.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention."""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cu_seqlens(position_ids):
|
||||||
|
"""
|
||||||
|
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
|
||||||
|
|
||||||
|
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
|
||||||
|
"""
|
||||||
|
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
|
||||||
|
|
||||||
|
position_ids = position_ids.view(-1)
|
||||||
|
indices_q = (position_ids == 0).nonzero().view(-1)
|
||||||
|
|
||||||
|
cu_seq_lens_q = torch.cat(
|
||||||
|
(
|
||||||
|
indices_q.to(**tensor_kwargs),
|
||||||
|
torch.tensor(position_ids.size(), **tensor_kwargs),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return cu_seq_lens_q
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_next_decoder_layer():
|
||||||
|
"""Patch Qwen3NextDecoderLayer to pass position_ids to linear attention."""
|
||||||
|
try:
|
||||||
|
from transformers.models.qwen3_next.modeling_qwen3_next import (
|
||||||
|
Qwen3NextDecoderLayer,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning("Qwen3Next model not found, skipping patch")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store original forward method
|
||||||
|
original_decoder_forward = Qwen3NextDecoderLayer.forward
|
||||||
|
|
||||||
|
def patched_decoder_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Token Mixer
|
||||||
|
if self.layer_type == "linear_attention":
|
||||||
|
hidden_states = self.linear_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
cache_params=past_key_values,
|
||||||
|
cache_position=cache_position,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
elif self.layer_type == "full_attention":
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, _ = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
# For the MoE layers, we need to unpack
|
||||||
|
if isinstance(hidden_states, Tuple):
|
||||||
|
hidden_states, _ = hidden_states
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
# Apply the patches
|
||||||
|
Qwen3NextDecoderLayer.forward = patched_decoder_forward
|
||||||
|
|
||||||
|
def unpatch():
|
||||||
|
"""Restore the original forward method"""
|
||||||
|
Qwen3NextDecoderLayer.forward = original_decoder_forward
|
||||||
|
|
||||||
|
return unpatch
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_next_gateddelta_layer():
|
||||||
|
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
|
||||||
|
try:
|
||||||
|
from transformers.models.qwen3_next.modeling_qwen3_next import (
|
||||||
|
Qwen3NextDynamicCache,
|
||||||
|
Qwen3NextGatedDeltaNet,
|
||||||
|
apply_mask_to_padding_states,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning("Qwen3Next model not found, skipping patch")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store original forward method
|
||||||
|
original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward
|
||||||
|
|
||||||
|
def patched_gated_delta_net_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cache_params: Optional[Qwen3NextDynamicCache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||||
|
|
||||||
|
# Set up dimensions for reshapes later
|
||||||
|
batch_size, seq_len, _ = hidden_states.shape
|
||||||
|
|
||||||
|
use_precomputed_states = (
|
||||||
|
cache_params is not None
|
||||||
|
and cache_params.has_previous_state
|
||||||
|
and seq_len == 1
|
||||||
|
and cache_position is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# getting projected states from cache if it exists
|
||||||
|
if cache_params is not None:
|
||||||
|
conv_state = cache_params.conv_states[self.layer_idx]
|
||||||
|
recurrent_state = cache_params.recurrent_states[self.layer_idx]
|
||||||
|
|
||||||
|
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
||||||
|
projected_states_ba = self.in_proj_ba(hidden_states)
|
||||||
|
query, key, value, z, b, a = self.fix_query_key_value_ordering(
|
||||||
|
projected_states_qkvz, projected_states_ba
|
||||||
|
)
|
||||||
|
query, key, value = (
|
||||||
|
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
||||||
|
)
|
||||||
|
|
||||||
|
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
|
|
||||||
|
if use_precomputed_states:
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
# NOTE: the conv state is updated in `causal_conv1d_update`
|
||||||
|
mixed_qkv = self.causal_conv1d_update(
|
||||||
|
mixed_qkv,
|
||||||
|
conv_state,
|
||||||
|
self.conv1d.weight.squeeze(1),
|
||||||
|
self.conv1d.bias,
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if cache_params is not None:
|
||||||
|
conv_state = F.pad(
|
||||||
|
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
|
||||||
|
)
|
||||||
|
cache_params.conv_states[self.layer_idx] = conv_state
|
||||||
|
if self.causal_conv1d_fn is not None:
|
||||||
|
mixed_qkv = self.causal_conv1d_fn(
|
||||||
|
x=mixed_qkv,
|
||||||
|
weight=self.conv1d.weight.squeeze(1),
|
||||||
|
bias=self.conv1d.bias,
|
||||||
|
activation=self.activation,
|
||||||
|
seq_idx=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||||
|
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
|
query, key, value = torch.split(
|
||||||
|
mixed_qkv,
|
||||||
|
[
|
||||||
|
self.key_dim,
|
||||||
|
self.key_dim,
|
||||||
|
self.value_dim,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
|
||||||
|
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
|
||||||
|
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
|
||||||
|
|
||||||
|
beta = b.sigmoid()
|
||||||
|
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||||
|
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
||||||
|
if self.num_v_heads // self.num_k_heads > 1:
|
||||||
|
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||||
|
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||||
|
|
||||||
|
if not use_precomputed_states:
|
||||||
|
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||||
|
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
g=g,
|
||||||
|
beta=beta,
|
||||||
|
initial_state=None,
|
||||||
|
output_final_state=cache_params is not None,
|
||||||
|
use_qk_l2norm_in_kernel=True,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
g=g,
|
||||||
|
beta=beta,
|
||||||
|
initial_state=recurrent_state,
|
||||||
|
output_final_state=cache_params is not None,
|
||||||
|
use_qk_l2norm_in_kernel=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
if cache_params is not None:
|
||||||
|
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
|
||||||
|
|
||||||
|
z_shape_og = z.shape
|
||||||
|
# reshape input data into 2D tensor
|
||||||
|
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
|
||||||
|
z = z.reshape(-1, z.shape[-1])
|
||||||
|
core_attn_out = self.norm(core_attn_out, z)
|
||||||
|
core_attn_out = core_attn_out.reshape(z_shape_og)
|
||||||
|
core_attn_out = core_attn_out.reshape(
|
||||||
|
core_attn_out.shape[0], core_attn_out.shape[1], -1
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.out_proj(core_attn_out)
|
||||||
|
return output
|
||||||
|
|
||||||
|
# Apply the patches
|
||||||
|
Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward
|
||||||
|
|
||||||
|
def unpatch():
|
||||||
|
"""Restore the original forward method"""
|
||||||
|
Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward
|
||||||
|
|
||||||
|
return unpatch
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_next_imports():
|
||||||
|
"""Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available."""
|
||||||
|
try:
|
||||||
|
import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning("Qwen3Next model not found, skipping import patch")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save original values for unpatch
|
||||||
|
original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None)
|
||||||
|
original_chunk_gated_delta_rule = getattr(
|
||||||
|
qwen3_modeling, "chunk_gated_delta_rule", None
|
||||||
|
)
|
||||||
|
original_fused_recurrent_gated_delta_rule = getattr(
|
||||||
|
qwen3_modeling, "fused_recurrent_gated_delta_rule", None
|
||||||
|
)
|
||||||
|
original_is_fast_path_available = getattr(
|
||||||
|
qwen3_modeling, "is_fast_path_available", False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fla.modules import FusedRMSNormGated
|
||||||
|
from fla.ops.gated_delta_rule import (
|
||||||
|
chunk_gated_delta_rule,
|
||||||
|
fused_recurrent_gated_delta_rule,
|
||||||
|
)
|
||||||
|
|
||||||
|
qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated
|
||||||
|
qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule
|
||||||
|
qwen3_modeling.fused_recurrent_gated_delta_rule = (
|
||||||
|
fused_recurrent_gated_delta_rule
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force is_fast_path_available to be True
|
||||||
|
# fla has triton kernels for causal_conv1d
|
||||||
|
qwen3_modeling.is_fast_path_available = True
|
||||||
|
except ImportError:
|
||||||
|
qwen3_modeling.chunk_gated_delta_rule = None
|
||||||
|
qwen3_modeling.fused_recurrent_gated_delta_rule = None
|
||||||
|
qwen3_modeling.FusedRMSNormGated = None
|
||||||
|
|
||||||
|
def unpatch():
|
||||||
|
"""Restore the original import values"""
|
||||||
|
qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated
|
||||||
|
qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule
|
||||||
|
qwen3_modeling.fused_recurrent_gated_delta_rule = (
|
||||||
|
original_fused_recurrent_gated_delta_rule
|
||||||
|
)
|
||||||
|
qwen3_modeling.is_fast_path_available = original_is_fast_path_available
|
||||||
|
|
||||||
|
return unpatch
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_next_modeling_packing():
|
||||||
|
"""Apply all Qwen3Next model patches."""
|
||||||
|
patch_qwen3_next_imports()
|
||||||
|
patch_qwen3_next_decoder_layer()
|
||||||
|
patch_qwen3_next_gateddelta_layer()
|
||||||
|
|
||||||
|
LOG.info("Applied Qwen3Next patch for packing")
|
||||||
@@ -11,6 +11,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
|||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||||
|
"apertus",
|
||||||
"mllama_text_model",
|
"mllama_text_model",
|
||||||
"llama",
|
"llama",
|
||||||
"llama4",
|
"llama4",
|
||||||
@@ -20,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
"qwen3",
|
"qwen3",
|
||||||
"qwen3_moe",
|
"qwen3_moe",
|
||||||
|
"qwen3_next",
|
||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"phi3",
|
"phi3",
|
||||||
|
|||||||
90
src/axolotl/monkeypatch/trainer/dpo_chunked.py
Normal file
90
src/axolotl/monkeypatch/trainer/dpo_chunked.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Monkeypatch helpers to reduce fp32 materialization during DPO training."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from trl import DPOTrainer
|
||||||
|
|
||||||
|
_PATCHED = False
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_patch_targets(model) -> Iterable[torch.nn.Module]:
|
||||||
|
current = model
|
||||||
|
seen: set[int] = set()
|
||||||
|
while current is not None and id(current) not in seen:
|
||||||
|
seen.add(id(current))
|
||||||
|
yield current
|
||||||
|
current = getattr(current, "module", None)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_unwrapped_forward(module):
|
||||||
|
forward = getattr(module, "forward", None)
|
||||||
|
if forward is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if hasattr(forward, "__wrapped__"):
|
||||||
|
unwrapped = forward.__wrapped__
|
||||||
|
return MethodType(unwrapped, module)
|
||||||
|
|
||||||
|
original = getattr(module, "_original_forward", None)
|
||||||
|
if original is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
func = original.__func__ if hasattr(original, "__func__") else original
|
||||||
|
return MethodType(func, module)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _temporarily_disable_output_fp32(model):
|
||||||
|
patched = []
|
||||||
|
for target in _iter_patch_targets(model):
|
||||||
|
replacement = _resolve_unwrapped_forward(target)
|
||||||
|
if replacement is None:
|
||||||
|
continue
|
||||||
|
patched.append((target, target.forward, replacement))
|
||||||
|
|
||||||
|
try:
|
||||||
|
for module, _, replacement in patched:
|
||||||
|
module.forward = replacement
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for module, original_forward, _ in reversed(patched):
|
||||||
|
module.forward = original_forward
|
||||||
|
|
||||||
|
|
||||||
|
def _cast_fp32_outputs(output: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||||
|
if not isinstance(output, dict):
|
||||||
|
return output
|
||||||
|
|
||||||
|
for key, value in output.items():
|
||||||
|
if torch.is_tensor(value) and value.dtype in (torch.float16, torch.bfloat16):
|
||||||
|
output[key] = value.float()
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def patch_dpo_disable_output_fp32():
|
||||||
|
"""Patch TRL's DPOTrainer to skip Accelerate's convert_to_fp32 wrapper when requested."""
|
||||||
|
global _PATCHED
|
||||||
|
if _PATCHED:
|
||||||
|
return
|
||||||
|
|
||||||
|
original_concatenated_forward = DPOTrainer.concatenated_forward
|
||||||
|
|
||||||
|
def patched_concatenated_forward(self, model, batch, is_ref_model: bool = False):
|
||||||
|
if not getattr(self.args, "disable_output_fp32", False):
|
||||||
|
return original_concatenated_forward(
|
||||||
|
self, model, batch, is_ref_model=is_ref_model
|
||||||
|
)
|
||||||
|
|
||||||
|
with _temporarily_disable_output_fp32(model):
|
||||||
|
result = original_concatenated_forward(
|
||||||
|
self, model, batch, is_ref_model=is_ref_model
|
||||||
|
)
|
||||||
|
return _cast_fp32_outputs(result)
|
||||||
|
|
||||||
|
DPOTrainer.concatenated_forward = patched_concatenated_forward
|
||||||
|
_PATCHED = True
|
||||||
@@ -11,6 +11,7 @@ from transformers.image_utils import load_image
|
|||||||
|
|
||||||
from axolotl.utils.dict import remove_none_values
|
from axolotl.utils.dict import remove_none_values
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -421,6 +422,36 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3ProcessingStrategy(ProcessingStrategy):
|
||||||
|
"""Processing Strategy class for Mistral3"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: Mistral3Processor,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
image_size: int | tuple[int, int] | None = None,
|
||||||
|
image_resize_algorithm: Resampling | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||||
|
special_ids = (
|
||||||
|
processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
self.image_token = special_ids.img
|
||||||
|
self.image_break_token = special_ids.img_break
|
||||||
|
self.image_end_token = special_ids.img_end
|
||||||
|
|
||||||
|
def process_labels(self, input_ids):
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||||
|
labels[labels == self.image_token] = -100
|
||||||
|
labels[labels == self.image_break_token] = -100
|
||||||
|
labels[labels == self.image_end_token] = -100
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
def get_processing_strategy(
|
def get_processing_strategy(
|
||||||
processor: ProcessorMixin,
|
processor: ProcessorMixin,
|
||||||
chat_template,
|
chat_template,
|
||||||
@@ -463,6 +494,11 @@ def get_processing_strategy(
|
|||||||
**processing_kwargs,
|
**processing_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(processor, Mistral3Processor):
|
||||||
|
return Mistral3ProcessingStrategy(
|
||||||
|
**processing_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# llama3_2_vision, llama4, llava
|
# llama3_2_vision, llama4, llava
|
||||||
# mistral_v7_tekken, pixtral, lfm2vl
|
# mistral_v7_tekken, pixtral, lfm2vl
|
||||||
return ProcessingStrategy(
|
return ProcessingStrategy(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Init for `axolotl.utils.mistral` module."""
|
"""Init for `axolotl.utils.mistral` module."""
|
||||||
|
|
||||||
|
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
|
||||||
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
|
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
|
||||||
|
|
||||||
__all__ = ["HFMistralTokenizer"]
|
__all__ = ["HFMistralTokenizer", "Mistral3Processor"]
|
||||||
|
|||||||
169
src/axolotl/utils/mistral/mistral3_processor.py
Normal file
169
src/axolotl/utils/mistral/mistral3_processor.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Processor for Mistral3 multimodal models with image support"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import ProcessorMixin
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
from transformers.processing_utils import ProcessingKwargs
|
||||||
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3ProcessorKwargs(ProcessingKwargs):
|
||||||
|
_defaults: Dict[str, Dict[str, Any]] = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"padding": True,
|
||||||
|
},
|
||||||
|
"common_kwargs": {
|
||||||
|
"return_tensors": "pt",
|
||||||
|
"return_dict": True,
|
||||||
|
"tokenize": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3Processor(ProcessorMixin):
|
||||||
|
"""
|
||||||
|
Processor for Mistral3 multimodal models that handles text and images.
|
||||||
|
Wraps HFMistralTokenizer and adds image processing capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attributes = ["tokenizer"]
|
||||||
|
tokenizer_class = "HFMistralTokenizer"
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: HFMistralTokenizer):
|
||||||
|
# Don't call super().__init__ to avoid the class validation issue
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_template(self) -> None:
|
||||||
|
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_tokenizer(self) -> None:
|
||||||
|
"""Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _merge_kwargs(
|
||||||
|
self, processor_kwargs_class: Any, **kwargs: Any
|
||||||
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""Merge kwargs with defaults similar to ProcessorMixin"""
|
||||||
|
defaults = processor_kwargs_class._defaults
|
||||||
|
output_kwargs: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
for kwarg_type, default_values in defaults.items():
|
||||||
|
output_kwargs[kwarg_type] = {**default_values}
|
||||||
|
|
||||||
|
# Update with provided kwargs
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
# Try to match key to appropriate kwarg type
|
||||||
|
if key in ["padding", "truncation", "max_length"]:
|
||||||
|
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
|
||||||
|
elif key in ["return_tensors", "return_dict", "tokenize"]:
|
||||||
|
output_kwargs.setdefault("common_kwargs", {}).update({key: value})
|
||||||
|
else:
|
||||||
|
# Add to text_kwargs by default
|
||||||
|
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
|
||||||
|
|
||||||
|
return output_kwargs
|
||||||
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[BatchFeature, str, list[str]]:
|
||||||
|
"""
|
||||||
|
Apply chat template with image support for Mistral3.
|
||||||
|
|
||||||
|
Similar to VoxtralProcessor, this method extracts images from the conversation,
|
||||||
|
calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes
|
||||||
|
to the result.
|
||||||
|
"""
|
||||||
|
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
|
||||||
|
text_kwargs = output_kwargs["text_kwargs"]
|
||||||
|
common_kwargs = output_kwargs["common_kwargs"]
|
||||||
|
|
||||||
|
return_tensors = common_kwargs.pop("return_tensors", "pt")
|
||||||
|
if return_tensors != "pt":
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} only supports `return_tensors='pt'`."
|
||||||
|
)
|
||||||
|
|
||||||
|
return_dict = common_kwargs.pop("return_dict", False)
|
||||||
|
tokenize = common_kwargs.pop("tokenize", False)
|
||||||
|
|
||||||
|
# Determine if batched
|
||||||
|
if isinstance(conversation, (list, tuple)) and (
|
||||||
|
isinstance(conversation[0], (list, tuple))
|
||||||
|
or hasattr(conversation[0], "content")
|
||||||
|
):
|
||||||
|
is_batched = True
|
||||||
|
conversations = conversation
|
||||||
|
else:
|
||||||
|
is_batched = False
|
||||||
|
conversations = [conversation] # type: ignore
|
||||||
|
|
||||||
|
# Call tokenizer's apply_chat_template
|
||||||
|
tokenizer_kwargs = {**text_kwargs, **common_kwargs}
|
||||||
|
tokenizer_kwargs["return_tensors"] = return_tensors
|
||||||
|
tokenizer_kwargs["tokenize"] = tokenize
|
||||||
|
tokenizer_kwargs["return_dict"] = return_dict
|
||||||
|
|
||||||
|
encoded_instruct_inputs = self.tokenizer.apply_chat_template(
|
||||||
|
conversations,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tokenize:
|
||||||
|
if return_dict:
|
||||||
|
# The tokenizer already handles pixel_values, we just need to add image_sizes
|
||||||
|
if hasattr(encoded_instruct_inputs, "items"):
|
||||||
|
data: Dict[str, Any] = dict(encoded_instruct_inputs) # type: ignore
|
||||||
|
elif hasattr(encoded_instruct_inputs, "data"):
|
||||||
|
data = encoded_instruct_inputs.data # type: ignore
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown data type")
|
||||||
|
|
||||||
|
if "pixel_values" in data:
|
||||||
|
pixel_values = data["pixel_values"]
|
||||||
|
|
||||||
|
# MistralTokenizer returns a Double, so we convert to fp32
|
||||||
|
data["pixel_values"] = pixel_values.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
# Always batched: [B, C, H, W] -> image_sizes: [B, 2]
|
||||||
|
# Since tensor is homogeneous, all images have same H, W
|
||||||
|
batch_size = pixel_values.shape[0]
|
||||||
|
image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size)
|
||||||
|
data["image_sizes"] = image_sizes
|
||||||
|
|
||||||
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
if not is_batched:
|
||||||
|
return encoded_instruct_inputs[0]
|
||||||
|
|
||||||
|
return encoded_instruct_inputs
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text: Optional[
|
||||||
|
Union[
|
||||||
|
TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
|
||||||
|
]
|
||||||
|
],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
Forward text processing to the tokenizer.
|
||||||
|
This method does not support images - use apply_chat_template instead.
|
||||||
|
"""
|
||||||
|
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
|
||||||
|
text_kwargs = output_kwargs["text_kwargs"]
|
||||||
|
common_kwargs = output_kwargs["common_kwargs"]
|
||||||
|
|
||||||
|
out = self.tokenizer(text, **text_kwargs)
|
||||||
|
return BatchFeature(
|
||||||
|
data=out, tensor_type=common_kwargs.pop("return_tensors", None)
|
||||||
|
)
|
||||||
@@ -160,6 +160,12 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dpo_use_logits_to_keep: bool | None = None
|
dpo_use_logits_to_keep: bool | None = None
|
||||||
|
dpo_disable_output_fp32: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Set to true to bypass Accelerate's automatic fp32 upcast in DPO forward passes and rely on chunked computations for lower VRAM usage."
|
||||||
|
},
|
||||||
|
)
|
||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
dpo_norm_loss: bool | None = None
|
dpo_norm_loss: bool | None = None
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
@@ -436,8 +442,8 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int | None = Field(
|
||||||
default=512,
|
default=None,
|
||||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||||
)
|
)
|
||||||
sample_packing: bool | None = Field(
|
sample_packing: bool | None = Field(
|
||||||
|
|||||||
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Integration tests for MistralCommonTokenizer patches."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestMistralTokenizerPatchIntegration:
|
||||||
|
"""Test MistralCommonTokenizer patch integration."""
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_mistral_tokenizer_image_patch(self):
|
||||||
|
"""Test that MistralCommonTokenizer image patch can be applied."""
|
||||||
|
try:
|
||||||
|
from transformers.tokenization_mistral_common import MistralCommonTokenizer
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("MistralCommonTokenizer not available")
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
|
||||||
|
apply_mistral_tokenizer_image_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method
|
||||||
|
original_apply_chat_template = MistralCommonTokenizer.apply_chat_template
|
||||||
|
|
||||||
|
# Apply patch
|
||||||
|
apply_mistral_tokenizer_image_patch()
|
||||||
|
|
||||||
|
# Verify patch was applied
|
||||||
|
assert (
|
||||||
|
MistralCommonTokenizer.apply_chat_template != original_apply_chat_template
|
||||||
|
), "apply_chat_template was not patched"
|
||||||
|
|
||||||
|
# Verify the method is still callable
|
||||||
|
assert callable(MistralCommonTokenizer.apply_chat_template), (
|
||||||
|
"Patched method is not callable"
|
||||||
|
)
|
||||||
77
tests/monkeypatch/test_pixtral_flash_attention_patch.py
Normal file
77
tests/monkeypatch/test_pixtral_flash_attention_patch.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Integration tests for Pixtral Flash Attention patches."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class TestPixtralFlashAttentionPatchIntegration:
|
||||||
|
"""Test Pixtral Flash Attention patch integration."""
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_pixtral_flash_attention_patch(self):
|
||||||
|
"""Test that Pixtral Flash Attention patch can be applied and works correctly."""
|
||||||
|
try:
|
||||||
|
from transformers import modeling_flash_attention_utils
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Flash Attention utils not available")
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
|
||||||
|
apply_patch_is_packed_sequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method
|
||||||
|
original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence
|
||||||
|
|
||||||
|
# Apply patch and get unpatch function
|
||||||
|
unpatch_fn = apply_patch_is_packed_sequence()
|
||||||
|
|
||||||
|
# Verify patch was applied
|
||||||
|
assert (
|
||||||
|
modeling_flash_attention_utils._is_packed_sequence
|
||||||
|
!= original_is_packed_sequence
|
||||||
|
), "_is_packed_sequence was not patched"
|
||||||
|
|
||||||
|
# Test the patched function with 1D position_ids
|
||||||
|
patched_fn = modeling_flash_attention_utils._is_packed_sequence
|
||||||
|
|
||||||
|
# Test 1D position_ids 1 sequence
|
||||||
|
position_ids_1d = torch.tensor([0, 1, 2, 3])
|
||||||
|
result = patched_fn(position_ids_1d, batch_size=1)
|
||||||
|
assert isinstance(result, bool), "Function should return a boolean"
|
||||||
|
assert result is False, "1D sequential position_ids should not be packed"
|
||||||
|
|
||||||
|
# Test 1D packed 2 sequences
|
||||||
|
position_ids_1d_packed = torch.tensor([0, 1, 2, 0, 1, 2])
|
||||||
|
result = patched_fn(position_ids_1d_packed, batch_size=1)
|
||||||
|
assert isinstance(result, bool), "Function should return a boolean"
|
||||||
|
assert result is True, "1D packed position_ids should be detected as packed"
|
||||||
|
|
||||||
|
# Test 2D packed 2 sequences
|
||||||
|
position_ids_2d_packed = torch.tensor([[0, 1, 2, 3, 0, 1]])
|
||||||
|
result = patched_fn(position_ids_2d_packed, batch_size=1)
|
||||||
|
assert isinstance(result, bool), "Function should return a boolean"
|
||||||
|
assert result is True, "2D packed position_ids should be detected as packed"
|
||||||
|
|
||||||
|
# Test 2D 1 sequence
|
||||||
|
position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||||
|
result = patched_fn(position_ids_2d_normal, batch_size=1)
|
||||||
|
assert isinstance(result, bool), "Function should return a boolean"
|
||||||
|
assert result is False, "2D sequential position_ids should not be packed"
|
||||||
|
|
||||||
|
# Test 2D batch size 2
|
||||||
|
position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
|
||||||
|
result = patched_fn(position_ids_2d_normal, batch_size=2)
|
||||||
|
assert isinstance(result, bool), "Function should return a boolean"
|
||||||
|
assert result is False, "2D position_ids batch 2 should not be packed"
|
||||||
|
|
||||||
|
# Test None case
|
||||||
|
result = patched_fn(None, batch_size=1)
|
||||||
|
assert isinstance(result, bool), "Function should return a boolean"
|
||||||
|
assert result is False, "None position_ids should return False"
|
||||||
|
|
||||||
|
# Test unpatch function
|
||||||
|
unpatch_fn()
|
||||||
|
assert (
|
||||||
|
modeling_flash_attention_utils._is_packed_sequence
|
||||||
|
== original_is_packed_sequence
|
||||||
|
), "unpatch function did not restore original method"
|
||||||
111
tests/monkeypatch/test_qwen3_next_modeling_patch.py
Normal file
111
tests/monkeypatch/test_qwen3_next_modeling_patch.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Integration tests for Qwen3 Next modeling patches."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Skip entire module if qwen3_next not available
|
||||||
|
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3NextModelingPatchIntegration:
|
||||||
|
"""Test Qwen3 Next modeling patch integration."""
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_qwen3_next_decoder_layer_patch(self):
|
||||||
|
"""Test that Qwen3Next decoder layer patch can be applied."""
|
||||||
|
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||||
|
patch_qwen3_next_decoder_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method
|
||||||
|
original_forward = qwen3_next.Qwen3NextDecoderLayer.forward
|
||||||
|
|
||||||
|
# Apply patch and get unpatch function
|
||||||
|
unpatch_fn = patch_qwen3_next_decoder_layer()
|
||||||
|
|
||||||
|
# Verify patch was applied
|
||||||
|
assert qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, (
|
||||||
|
"decoder layer forward method was not patched"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the method is still callable
|
||||||
|
assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), (
|
||||||
|
"Patched method is not callable"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test unpatch function
|
||||||
|
if unpatch_fn:
|
||||||
|
unpatch_fn()
|
||||||
|
assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, (
|
||||||
|
"unpatch function did not restore original method"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_qwen3_next_gateddelta_layer_patch(self):
|
||||||
|
"""Test that Qwen3Next GatedDeltaNet patch can be applied."""
|
||||||
|
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||||
|
patch_qwen3_next_gateddelta_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method
|
||||||
|
original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward
|
||||||
|
|
||||||
|
# Apply patch and get unpatch function
|
||||||
|
unpatch_fn = patch_qwen3_next_gateddelta_layer()
|
||||||
|
|
||||||
|
# Verify patch was applied
|
||||||
|
assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, (
|
||||||
|
"GatedDeltaNet forward method was not patched"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the method is still callable
|
||||||
|
assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), (
|
||||||
|
"Patched method is not callable"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test unpatch function
|
||||||
|
if unpatch_fn:
|
||||||
|
unpatch_fn()
|
||||||
|
assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, (
|
||||||
|
"unpatch function did not restore original method"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_qwen3_next_imports_patch(self):
|
||||||
|
"""Test that Qwen3Next imports patch can be applied without errors."""
|
||||||
|
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||||
|
patch_qwen3_next_imports,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply patch - should not raise any exceptions even if modules unavailable
|
||||||
|
unpatch_fn = patch_qwen3_next_imports()
|
||||||
|
|
||||||
|
# Test that unpatch function is returned (or None if skipped)
|
||||||
|
assert unpatch_fn is None or callable(unpatch_fn), (
|
||||||
|
"patch_qwen3_next_imports should return None or callable unpatch function"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_qwen3_next_modeling_packing_patch(self):
|
||||||
|
"""Test that all Qwen3Next modeling patches can be applied together."""
|
||||||
|
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||||
|
patch_qwen3_next_modeling_packing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should not raise any exceptions
|
||||||
|
patch_qwen3_next_modeling_packing()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_get_cu_seqlens_utility():
|
||||||
|
"""Test the get_cu_seqlens utility function."""
|
||||||
|
from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens
|
||||||
|
|
||||||
|
# Test with simple position_ids
|
||||||
|
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
|
||||||
|
cu_seqlens = get_cu_seqlens(position_ids)
|
||||||
|
assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype"
|
||||||
|
|
||||||
|
# Should return tensor with start positions and total length
|
||||||
|
expected = torch.tensor([0, 3, 5], dtype=torch.int32)
|
||||||
|
assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"
|
||||||
43
tests/monkeypatch/test_voxtral_modeling_patch.py
Normal file
43
tests/monkeypatch/test_voxtral_modeling_patch.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Integration tests for Voxtral modeling patches."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoxtralModelingPatchIntegration:
|
||||||
|
"""Test Voxtral modeling patch integration."""
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_voxtral_conditional_generation_patch(self):
|
||||||
|
"""Test that Voxtral conditional generation patch can be applied."""
|
||||||
|
try:
|
||||||
|
from transformers.models.voxtral.modeling_voxtral import (
|
||||||
|
VoxtralForConditionalGeneration,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("VoxtralForConditionalGeneration not available")
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.models.voxtral.modeling import (
|
||||||
|
patch_voxtral_conditional_generation_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method
|
||||||
|
original_forward = VoxtralForConditionalGeneration.forward
|
||||||
|
|
||||||
|
# Apply patch and get unpatch function
|
||||||
|
unpatch_fn = patch_voxtral_conditional_generation_forward()
|
||||||
|
|
||||||
|
# Verify patch was applied
|
||||||
|
assert VoxtralForConditionalGeneration.forward != original_forward, (
|
||||||
|
"forward method was not patched"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the method is still callable
|
||||||
|
assert callable(VoxtralForConditionalGeneration.forward), (
|
||||||
|
"Patched method is not callable"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test unpatch function
|
||||||
|
unpatch_fn()
|
||||||
|
assert VoxtralForConditionalGeneration.forward == original_forward, (
|
||||||
|
"unpatch function did not restore original method"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user