# Compare commits

**2 commits** — `coderabbit` ... `fix/diffus`

| Author | SHA1 | Date |
|---|---|---|
| | 08c8f3f22f | |
| | 76f0fe2621 | |
## .github/workflows/tests.yml (30 changes)
````diff
@@ -66,12 +66,12 @@ jobs:
       - name: Check out repository code
         uses: actions/checkout@v4
 
-      # - name: Restore Cache from S3
-      #   id: hf-cache-restore-s3
-      #   run: |
-      #     mkdir -p ~/.cache/huggingface/hub
-      #     curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
-      #
+      - name: Restore Cache from S3
+        id: hf-cache-restore-s3
+        run: |
+          mkdir -p /home/runner/.cache/huggingface/hub
+          curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
+
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -113,13 +113,9 @@ jobs:
 
       - name: Run tests
         run: |
-          df -h
           pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
-          df -h
           pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
-          df -h
           pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
-          df -h
           pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
 
       - name: Upload coverage to Codecov
@@ -149,12 +145,12 @@ jobs:
       - name: Check out repository code
         uses: actions/checkout@v4
 
-      # - name: Restore Cache from S3
-      #   id: hf-cache-restore-s3
-      #   run: |
-      #     mkdir -p ~/.cache/huggingface/hub
-      #     curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
-      #
+      - name: Restore Cache from S3
+        id: hf-cache-restore-s3
+        run: |
+          mkdir -p /home/runner/.cache/huggingface/hub
+          curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
+
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -192,7 +188,7 @@ jobs:
           axolotl --help
 
       - name: Show HF cache
-        run: hf cache scan
+        run: huggingface-cli scan-cache
 
       - name: Run tests
         run: |
````
````diff
@@ -10,7 +10,6 @@ ARG BASE_VOLUME="/runpod-volume"
 ENV BASE_VOLUME=$BASE_VOLUME
 ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
-ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
 ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
 
 COPY .runpod/src /src
````
````diff
@@ -29,7 +29,7 @@
 
 ## 🎉 Latest Updates
 
-- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
+- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3).
 - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
 - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
 - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
````
````diff
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\""
    ]
   },
   {
@@ -253,6 +253,7 @@
    "source": [
     "from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
     "\n",
+    "# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
     "set_pytorch_cuda_alloc_conf()"
    ]
   },
````
````diff
@@ -30,7 +30,7 @@ eval_sample_packing: true
 gradient_accumulation_steps: 4
 micro_batch_size: 4
 num_epochs: 1
-warmup_steps: 0.1
+warmup_ratio: 0.1
 
 optimizer: adamw_8bit
 lr_scheduler: cosine
@@ -44,7 +44,7 @@ resume_from_checkpoint:
 sdp_attention: true
 
 logging_steps: 1
-save_strategy: best
+save_strategy: epoch
 eval_strategy: epoch
 
 special_tokens:
````
````diff
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
 Here is an example of how to install from pip:
 
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
 pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```
````
````diff
@@ -1,50 +0,0 @@
-# Finetune Ministral with Axolotl
-
-Ministral is a family of openweight models from MistralAI found on [HuggingFace](mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
-
-## Getting started
-
-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
-
-2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
-
-3. Run the finetuning example:
-
-```bash
-axolotl train examples/ministral/ministral-small-qlora.yaml
-```
-
-This config uses about 8.76 GiB VRAM.
-
-Let us know how it goes. Happy finetuning! 🚀
-
-### Tips
-
-- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
-- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
-- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
-- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
-
-## Optimization Guides
-
-Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
-
-## Limitations
-
-We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
-
-In addition, we do not support overriding tokens yet.
-
-## Related Resources
-
-- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux)
-- [Axolotl Docs](https://docs.axolotl.ai)
-- [Axolotl Website](https://axolotl.ai)
-- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
-- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
-
-
-## Future Work
-
-- Add parity to Preference Tuning, RL, etc.
-- Add parity to other tokenizer configs like overriding tokens.
````
````diff
@@ -1,67 +0,0 @@
-base_model: mistralai/Ministral-8B-Instruct-2410
-
-# Enable to use mistral-common tokenizer
-tokenizer_use_mistral_common: true
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-
-load_in_8bit: false
-load_in_4bit: true
-
-datasets:
-  - path: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.1
-output_dir: ./outputs/lora-out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-sample_packing: true
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_target_modules:
-  - gate_proj
-  - down_proj
-  - up_proj
-  - q_proj
-  - v_proj
-  - k_proj
-  - o_proj
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: auto
-tf32: false
-
-gradient_checkpointing: true
-resume_from_checkpoint:
-logging_steps: 1
-flash_attention: true
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,79 +0,0 @@
-# Finetune Ministral3 with Axolotl
-
-Ministral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
-
-Please see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning.
-
-Thanks to the team at MistralAI for giving us early access to prepare for these releases.
-
-Note: This is still experimental given it is based on transformers v5 RC.
-
-## Getting started
-
-1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
-
-2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
-
-3. Swap to the Axolotl transformers v5 branch
-
-```bash
-cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml
-
-git fetch
-git checkout transformers-v5
-
-# Install packages for transformers v5
-pip install -e .
-```
-
-4. Run the fine-tuning:
-
-```bash
-axolotl train ministral3-3b-qlora.yaml
-```
-
-Let us know how it goes. Happy finetuning! 🚀
-
-
-### Tips
-
-- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
-- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
-- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
-- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
-
-### Thinking
-
-Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
-
-📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
-
-### Vision
-
-Ministral3 2512 model also supports vision capabilities.
-
-📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
-
-## Optimization Guides
-
-Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
-
-## Limitations
-
-We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
-
-In addition, we do not support overriding tokens yet.
-
-## Related Resources
-
-- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3)
-- [Axolotl Docs](https://docs.axolotl.ai)
-- [Axolotl Website](https://axolotl.ai)
-- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
-- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
-
-
-## Future Work
-
-- Add parity to Preference Tuning, RL, etc.
-- Add parity to other tokenizer configs like overriding tokens.
````
````diff
@@ -1,67 +0,0 @@
-base_model: mistralai/Ministral-3-3B-Reasoning-2512
-
-# Enable to use mistral-common tokenizer
-tokenizer_use_mistral_common: true
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-
-load_in_8bit: false
-load_in_4bit: true
-
-datasets:
-  - path: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.1
-output_dir: ./outputs/lora-out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-sample_packing: true
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_target_modules:
-  - gate_proj
-  - down_proj
-  - up_proj
-  - q_proj
-  - v_proj
-  - k_proj
-  - o_proj
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: auto
-tf32: false
-
-gradient_checkpointing: true
-resume_from_checkpoint:
-logging_steps: 1
-flash_attention: true
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,73 +0,0 @@
-# Ministral3 2512 Thinking Fine-tuning
-
-This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
-
-## Prerequisites
-
-Before starting, ensure you have:
-- Installed Axolotl (see [main README](../README.md))
-
-## Getting Started
-
-Run the thinking model fine-tuning:
-
-```bash
-axolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml
-```
-
-This config uses about 4.76 GiB VRAM.
-
-### Tips
-
-- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
-- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
-
-## Dataset Format
-
-The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
-
-Example format:
-
-```json
-{
-  "messages": [
-    {
-      "role": "system",
-      "content": [
-        { "type": "text", "text": "{SYSTEM_PROMPT}"}
-      ]
-    },
-    {
-      "role": "user",
-      "content": [
-        { "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
-      ]
-    },
-    {
-      "role": "assistant",
-      "content": [
-        {
-          "type": "thinking",
-          "thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
-        },
-        {
-          "type": "text",
-          "text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
-        }
-      ]
-    }
-  ]
-}
-```
-
-### Advanced Options
-
-The `thinking` section supports an optional `closed` parameter:
-
-```json
-{
-  "type": "thinking",
-  "thinking": "Internal reasoning here...",
-  "closed": true // Default: true, controls adding the closing [/THINK] tag
-}
-```
````
````diff
@@ -1,67 +0,0 @@
-base_model: mistralai/Ministral-3-3B-Reasoning-2512
-
-# Enable to use mistral-common tokenizer
-tokenizer_use_mistral_common: true
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-
-load_in_8bit: false
-load_in_4bit: true
-
-datasets:
-  - path: Nanobit/text-think-2k-test
-    type: chat_template
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0
-output_dir: ./outputs/lora-out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-sample_packing: true
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_target_modules:
-  - gate_proj
-  - down_proj
-  - up_proj
-  - q_proj
-  - v_proj
-  - k_proj
-  - o_proj
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: auto
-tf32: false
-
-gradient_checkpointing: true
-resume_from_checkpoint:
-logging_steps: 1
-flash_attention: true
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,57 +0,0 @@
-# Ministral3 2512 Vision Fine-tuning
-
-This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl.
-
-## Prerequisites
-
-Before starting, ensure you have:
-- Installed Axolotl from source (see [main README](../README.md#getting-started))
-
-## Getting started
-
-1. Install the required vision lib:
-```bash
-pip install 'mistral-common[opencv]==1.8.6'
-```
-
-2. Download the example dataset image:
-```bash
-wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
-```
-
-3. Run the fine-tuning:
-```bash
-axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml
-```
-
-WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
-
-### Tips
-
-Key differences from text-only model:
-- Multi-modal dataset format required
-- Sample packing not supported
-
-## Dataset Format
-
-The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
-
-One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
-
-Example:
-```json
-{
-  "messages": [
-    {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
-    {"role": "user", "content": [
-      { "type": "text", "text": "What's in this image?"},
-      {"type": "image", "path": "path/to/image.jpg" }
-    ]},
-    {"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
-  ],
-}
-```
-
-## Limitations
-
-- Sample Packing is not supported for multi-modality training currently.
````
````diff
@@ -1,64 +0,0 @@
-base_model: mistralai/Ministral-3-3B-Reasoning-2512
-processor_type: AutoProcessor
-
-# Enable to use mistral-common tokenizer
-tokenizer_use_mistral_common: true
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-
-load_in_4bit: true
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-# sample dataset below requires downloading image in advance
-# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
-datasets:
-  - path: Nanobit/text-vision-2k-test
-    type: chat_template
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.01
-output_dir: ./outputs/out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: true
-fp16:
-tf32: true
-
-gradient_checkpointing: true
-logging_steps: 1
-flash_attention: true
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-weight_decay: 0.0
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -6,16 +6,24 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 
 ## Getting started
 
 1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
 
-2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
-
-3. Run the finetuning example:
-
+Here is an example of how to install from pip:
+
 ```bash
-axolotl train examples/olmo3/olmo3-7b-qlora.yaml
+# Ensure you have a compatible version of Pytorch installed
+pip3 install packaging setuptools wheel ninja
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+
+# Install Cut Cross Entropy
+python scripts/cutcrossentropy_install.py | sh
 ```
 
+2. Run the finetuning example:
+
+```bash
+axolotl train examples/olmo3/olmo3-7b-qlora.yaml
+```
+
 Let us know how it goes. Happy finetuning! 🚀
 
 ### TIPS
````
````diff
@@ -1,67 +0,0 @@
-base_model: google/gemma-3-12b-it
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: gemma3
-datasets:
-  - path: tatsu-lab/alpaca
-    type: alpaca
-
-output_dir: ./outputs/out_gemma/
-
-sequence_len: 8096
-sample_packing: true
-flash_attention: true
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 4e-5
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,72 +0,0 @@
-base_model: google/gemma-3-12b-it
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: gemma3
-datasets:
-  - path: tatsu-lab/alpaca
-    type: alpaca
-
-output_dir: ./outputs/qat_out_gemma/
-
-sequence_len: 8096
-sample_packing: true
-flash_attention: true
-
-qat:
-  activation_dtype: nvfp4
-  weight_dtype: nvfp4
-  group_size: 16 # only group_size of 16 is supported with nvfp4
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 4e-5
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,67 +0,0 @@
-base_model: google/gemma-3-12b-it
-# Math finetuning configuration for Gemma3-12B
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: gemma3
-datasets:
-  - path: AI-MO/NuminaMath-CoT
-    type: chat_template
-
-output_dir: ./outputs/out_math_gemma/
-
-sequence_len: 4096
-sample_packing: true
-flash_attention: true
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 8
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 3e-5
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,72 +0,0 @@
-base_model: google/gemma-3-12b-it
-# Math finetuning configuration for Gemma3-12B
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: gemma3
-datasets:
-  - path: AI-MO/NuminaMath-CoT
-    type: chat_template
-
-output_dir: ./outputs/qat_out_math_gemma/
-
-sequence_len: 4096
-sample_packing: true
-flash_attention: true
-
-qat:
-  activation_dtype: nvfp4
-  weight_dtype: nvfp4
-  group_size: 16 # only group_size of 16 is supported with nvfp4
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 8
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 3e-5
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,68 +0,0 @@
-base_model: google/gemma-3-27b-it
-# Math finetuning configuration for Gemma3-27B
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: gemma3
-datasets:
-  - path: AI-MO/NuminaMath-CoT
-    type: chat_template
-
-output_dir: ./outputs/out_math_gemma27/
-
-sequence_len: 4096
-sample_packing: true
-flash_attention: true
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 5e-6
-eta_min: 7e-7
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,73 +0,0 @@
-base_model: google/gemma-3-27b-it
-# Math finetuning configuration for Gemma3-27B
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: gemma3
-datasets:
-  - path: AI-MO/NuminaMath-CoT
-    type: chat_template
-
-output_dir: ./outputs/qat_out_math_gemma27/
-
-sequence_len: 4096
-sample_packing: true
-flash_attention: true
-
-qat:
-  activation_dtype: nvfp4
-  weight_dtype: nvfp4
-  group_size: 16 # only group_size of 16 is supported with nvfp4
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 5e-6
-eta_min: 7e-7
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,67 +0,0 @@
-base_model: Qwen/Qwen2.5-72B
-# Math finetuning configuration for Qwen2.5-72B (non-instruct)
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: qwen_25
-datasets:
-  - path: AI-MO/NuminaMath-CoT
-    type: chat_template
-
-output_dir: ./outputs/out_math_72b/
-
-sequence_len: 4096
-sample_packing: true
-flash_attention: true
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 8
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 5e-6
-eta_min: 7e-7
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,72 +0,0 @@
-base_model: Qwen/Qwen2.5-72B
-# Math finetuning configuration for Qwen2.5-72B (non-instruct)
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: qwen_25
-datasets:
-  - path: AI-MO/NuminaMath-CoT
-    type: chat_template
-
-output_dir: ./outputs/qat_out_math_72b/
-
-sequence_len: 4096
-sample_packing: true
-flash_attention: true
-
-qat:
-  activation_dtype: nvfp4
-  weight_dtype: nvfp4
-  group_size: 16 # only group_size of 16 is supported with nvfp4
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 8
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 5e-6
-eta_min: 7e-7
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,67 +0,0 @@
-base_model: Qwen/Qwen2.5-72B
-# Alpaca finetuning configuration for Qwen2.5-72B
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: qwen_25
-datasets:
-  - path: tatsu-lab/alpaca
-    type: alpaca
-
-output_dir: ./outputs/out_qwen72b/
-
-sequence_len: 8096
-sample_packing: true
-flash_attention: true
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 2e-5
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,72 +0,0 @@
-base_model: Qwen/Qwen2.5-72B
-# Alpaca finetuning configuration for Qwen2.5-72B
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-seed: 42
-chat_template: qwen_25
-datasets:
-  - path: tatsu-lab/alpaca
-    type: alpaca
-
-output_dir: ./outputs/qat_out_qwen72b/
-
-sequence_len: 8096
-sample_packing: true
-flash_attention: true
-
-qat:
-  activation_dtype: nvfp4
-  weight_dtype: nvfp4
-  group_size: 16 # only group_size of 16 is supported with nvfp4
-
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-
-num_epochs: 1
-optimizer: adamw_torch_fused
-lr_scheduler: cosine
-learning_rate: 2e-5
-
-bf16: true
-tf32: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-# evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_ratio: 0.1
-weight_decay: 0.0
-fsdp_version: 2
-
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: true
-
-special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -1,46 +0,0 @@
-# Finetune Qwen3 with Axolotl
-
-[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba.
-
-This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
-
-## Getting started
-
-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
-
-2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
-
-3. Run the finetuning example:
-
-```bash
-axolotl train examples/qwen3/32b-qlora.yaml
-```
-
-Let us know how it goes. Happy finetuning! 🚀
-
-### Chat template masking a few tokens off
-
-If you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.
-
-```yaml
-chat_template: qwen3
-```
-
-### TIPS
-
-- For inference, please check the official model card as it depends on your reasoning mode.
-- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
-- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
-- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
-
-## Optimization Guides
-
-Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
-
-## Related Resources
-
-- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/)
-- [Axolotl Docs](https://docs.axolotl.ai)
-- [Axolotl Website](https://axolotl.ai)
-- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
-- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
````
````diff
@@ -1,38 +0,0 @@
-# Finetune ArceeAI's Trinity with Axolotl
-
-[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai.
-
-This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
-
-## Getting started
-
-1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
-
-2. Run the finetuning example:
-
-```bash
-axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
-```
-
-This config uses about 24.9 GiB VRAM.
-
-Let us know how it goes. Happy finetuning! 🚀
-
-### TIPS
-
-- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`.
-- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
-- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
-- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
-
-## Optimization Guides
-
-Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
-
-## Related Resources
-
-- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
-- [Axolotl Docs](https://docs.axolotl.ai)
-- [Axolotl Website](https://axolotl.ai)
-- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
-- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
````
@@ -1,67 +0,0 @@
-base_model: arcee-ai/Trinity-Nano-Preview
-trust_remote_code: true
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-# CCE - N/A as of now
-# plugins:
-#   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-
-load_in_8bit: false
-load_in_4bit: true
-
-datasets:
-  - path: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.1
-output_dir: ./outputs/lora-out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-sample_packing: true
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_target_modules:
-  - gate_proj
-  - down_proj
-  - up_proj
-  - q_proj
-  - v_proj
-  - k_proj
-  - o_proj
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: auto
-tf32: false
-
-gradient_checkpointing: true
-resume_from_checkpoint:
-logging_steps: 1
-# flash_attention: true # Not supported
-sdp_attention: true
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
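As a quick sanity check on the schedule in the config above, the effective batch size follows from `micro_batch_size × gradient_accumulation_steps × world_size`; a tiny sketch of that arithmetic, assuming a single GPU:

```python
# Effective batch size arithmetic for the config above (values from the YAML).
micro_batch_size = 2
gradient_accumulation_steps = 4
world_size = 1  # assumption: single GPU

effective_batch_size = micro_batch_size * gradient_accumulation_steps * world_size
print(effective_batch_size)  # 8 sequences (each up to sequence_len tokens) per optimizer step
```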
@@ -5,7 +5,7 @@ bitsandbytes==0.48.2
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
-liger-kernel==0.6.4
+liger-kernel==0.6.3
 # END section
 
 packaging==23.2
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.5
 # telemetry
 posthog==6.7.11
 
-mistral-common==1.8.6
+mistral-common==1.8.5
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
 
 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"'
 )
setup.py
@@ -66,6 +66,7 @@ def parse_requirements(extras_require_map):
         extras_require_map.pop("fbgemm-gpu")
         extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
         extras_require_map["vllm"] = ["vllm==0.11.1"]
+        _install_requires.pop(_install_requires.index(xformers_version))
     elif (major, minor) >= (2, 8):
         extras_require_map.pop("fbgemm-gpu")
         extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
     launch_training,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
-from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf
+from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.config import AxolotlInputConfig
 
@@ -45,7 +45,6 @@ def cli():
     print_axolotl_text_art()
     load_dotenv()
     set_pytorch_cuda_alloc_conf()
-    set_misc_env()
 
 
 @cli.command()
@@ -8,7 +8,7 @@ from typing import Union
|
|||||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_processor, load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.quantization import (
|
from axolotl.utils.quantization import (
|
||||||
TorchAOQuantDType,
|
TorchAOQuantDType,
|
||||||
@@ -66,11 +66,6 @@ def do_quantize(
|
|||||||
|
|
||||||
LOG.info(f"Loading model from {model_path}.")
|
LOG.info(f"Loading model from {model_path}.")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
processor = None
|
|
||||||
if cfg.is_multimodal:
|
|
||||||
processor = load_processor(cfg, tokenizer)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_path)
|
config = AutoConfig.from_pretrained(model_path)
|
||||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -112,10 +107,6 @@ def do_quantize(
|
|||||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
if processor:
|
|
||||||
LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
|
|
||||||
processor.save_pretrained(str(Path(output_dir) / "quantized"))
|
|
||||||
|
|
||||||
if hub_model_id:
|
if hub_model_id:
|
||||||
hub_model_id = (
|
hub_model_id = (
|
||||||
hub_model_id.rstrip("-")
|
hub_model_id.rstrip("-")
|
||||||
@@ -123,8 +114,6 @@ def do_quantize(
|
|||||||
)
|
)
|
||||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||||
tokenizer.push_to_hub(hub_model_id)
|
tokenizer.push_to_hub(hub_model_id)
|
||||||
if processor:
|
|
||||||
processor.push_to_hub(hub_model_id)
|
|
||||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||||
|
|
||||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||||
|
@@ -17,5 +17,4 @@ MOE_ARCH_BLOCK = {
     "deepseek_v3": "DeepseekV3MoE",
     "gpt_oss": "GptOssDecoderLayer",
     "lfm2_moe": "Lfm2MoeSparseMoeBlock",
-    "afmoe": "AfmoeMoE",
 }
@@ -631,10 +631,8 @@ class AxolotlTrainer(
             logs["tokens_per_second_per_gpu"] = round(
                 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
             )
-            if (
-                hasattr(self.state, "total_tokens")
-                and self.state.total_tokens is not None
-            ):
+            if hasattr(self.state, "total_tokens"):
                 logs["total_tokens"] = int(self.state.total_tokens.item())
 
         del self._stored_metrics[train_eval]
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
 
 - If you are installing from pip
   ```bash
-  pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
+  pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"
   ```
 
 ## Usage
@@ -61,8 +61,6 @@ plugins:
   - llama4
   - llama4_text
   - llava
-  - ministral
-  - ministral3
   - mistral
   - mistral3
   - mixtral
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`'
 )
 
 
@@ -179,17 +179,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         logprobs = prompt.pop(self.logprobs_field)
         tokenized_prompt = super()._tokenize_single_prompt(prompt)
         tokenized_prompt[self.logprobs_field] = logprobs
 
-        # let subclasses add fields before transform
-        tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
-
         tokenized_prompt = self.transform_logprobs(tokenized_prompt)
-        return tokenized_prompt
-
-    def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
-        """
-        Hook for subclasses to prepare additional KD fields before transform
-        """
         return tokenized_prompt
 
 
@@ -292,13 +283,14 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
 
         return sample
 
-    def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
-        """
-        Add pre-tokenized target_token_ids for v2 format
-        """
-        target_token_ids = original_prompt.pop("target_token_ids", None)
+    def _tokenize_single_prompt(self, prompt):
+        target_token_ids = prompt.get("target_token_ids", None)
+        tokenized_prompt = super()._tokenize_single_prompt(prompt)
         if target_token_ids is not None:
             tokenized_prompt["target_token_ids"] = target_token_ids
 
         return tokenized_prompt
 
 
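The removed `_prepare_kd_fields` is a Template Method hook: the base class fixes the tokenize, prepare, transform order, and the v2 strategy overrides only the hook. A standalone toy sketch of that pattern (illustrative class and field names, not the library's API):

```python
# Toy illustration of the Template Method hook the diff above removes.
class BaseStrategy:
    def tokenize(self, prompt: dict) -> dict:
        tokenized = {"input_ids": [1, 2, 3]}  # stand-in for real tokenization
        tokenized = self._prepare_kd_fields(tokenized, prompt)  # subclass hook
        return self._transform(tokenized)

    def _prepare_kd_fields(self, tokenized: dict, original: dict) -> dict:
        return tokenized  # default: no extra fields

    def _transform(self, tokenized: dict) -> dict:
        return tokenized


class V2Strategy(BaseStrategy):
    def _prepare_kd_fields(self, tokenized: dict, original: dict) -> dict:
        # v2 carries pre-tokenized teacher targets through to the transform step
        if "target_token_ids" in original:
            tokenized["target_token_ids"] = original["target_token_ids"]
        return tokenized


print(V2Strategy().tokenize({"target_token_ids": [[10, 20]]}))
```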
@@ -16,8 +16,6 @@
 KD trainer
 """
 
-from typing_extensions import override
-
 from axolotl.core.trainers.base import AxolotlTrainer
 
 from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
@@ -62,7 +60,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
         if columns_to_add:
             self._signature_columns += columns_to_add
 
-    @override
     def compute_loss(
         self,
         model,
@@ -82,22 +79,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
         ):
             del inputs["attention_mask"]
 
-        if num_items_in_batch is None and "labels" in inputs:
-            num_items_in_batch = (inputs["labels"] != -100).sum().item()
-
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
                 loss_kwargs["num_items_in_batch"] = num_items_in_batch
             inputs = {**inputs, **loss_kwargs}
 
         outputs = model(**inputs)
-        if isinstance(outputs, dict):
-            loss = outputs["loss"]
-        elif isinstance(outputs, tuple):
-            loss = outputs[0]
-        else:
-            loss = outputs.loss if hasattr(outputs, "loss") else outputs
-
-        return (loss, outputs) if return_outputs else loss
+        return outputs[0]
@@ -142,12 +142,9 @@ def load_lora(
     ):
         setup_quantized_meta_for_peft(model)
 
-    model_kwargs: Any = {}
-    if cfg.peft_autocast_adapter_dtype is not None:
-        model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
-
     if cfg.lora_model_dir:
         LOG.debug("Loading pretrained PEFT - LoRA")
+        model_kwargs: Any = {}
         if cfg.lora_on_cpu:
             model_kwargs["max_memory"] = {"cpu": "256GiB"}
             model_kwargs["device_map"] = {"": "cpu"}
@@ -158,7 +155,7 @@ def load_lora(
             **model_kwargs,
         )
     else:
-        model = get_peft_model(model, lora_config, **model_kwargs)
+        model = get_peft_model(model, lora_config)
 
     if rank == 0:
         try:
@@ -52,9 +52,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "olmo",
     "olmo2",
     "olmo3",
-    "ministral",
-    "ministral3",
-    "afmoe",
 ]
 
 
@@ -95,7 +95,6 @@ class ChatTemplatePrompter(Prompter):
         add_generation_prompt=False,
         images=None,
         tools=None,
-        real_last_index=None,
     ):
         """
         Build a prompt from a conversation.
@@ -115,9 +114,6 @@ class ChatTemplatePrompter(Prompter):
         if tools:
             chat_template_kwargs["tools"] = tools
 
-        if real_last_index:
-            chat_template_kwargs["real_last_index"] = real_last_index
-
         if self.processor:
             if not callable(self.processor):
                 raise TypeError("Processor must be callable")
@@ -635,17 +631,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             turns_with_empty = turns[:turn_idx] + [empty_turn]
             turns_with_content = turns[: turn_idx + 1]
 
-            real_last_index = len(turns) - 1
-
             # Generate the conversation up to the turn, with final turn replaced with dummy content
-            dummy_ids = self.prompter.build_prompt(
-                turns_with_empty, tools=tools, real_last_index=real_last_index
-            )  # type: ignore
+            dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools)  # type: ignore
 
             # Generate the conversation up to the turn, with final turn included
-            full_ids = self.prompter.build_prompt(
-                turns_with_content, tools=tools, real_last_index=real_last_index
-            )  # type: ignore
+            full_ids = self.prompter.build_prompt(turns_with_content, tools=tools)  # type: ignore
 
             if not full_ids or not dummy_ids:
                 LOG.warning(f"Empty template generated for turn {turn_idx}")
@@ -41,27 +41,14 @@ def get_pytorch_version() -> tuple[int, int, int]:
 
 
 def set_pytorch_cuda_alloc_conf():
-    """Set up CUDA allocation config"""
+    """Set up CUDA allocation config if using PyTorch >= 2.2"""
     torch_version = torch.__version__.split(".")
     torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
-    config_value = "expandable_segments:True,roundup_power2_divisions:16"
-    if (
-        torch_major == 2
-        and torch_minor >= 9
-        and os.getenv("PYTORCH_ALLOC_CONF") is None
-    ):
-        os.environ["PYTORCH_ALLOC_CONF"] = config_value
-    elif (
-        torch_major == 2
-        and torch_minor >= 2
-        and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
-    ):
-        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value
+    if torch_major == 2 and torch_minor >= 2:
+        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
+            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
+                "expandable_segments:True,roundup_power2_divisions:16"
+            )
 
 
-def set_misc_env():
-    if os.getenv("XFORMERS_IGNORE_FLASH_VERSION_CHECK") is None:
-        os.environ["XFORMERS_IGNORE_FLASH_VERSION_CHECK"] = "1"
-
-
 def get_not_null(value, default=None):
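Usage note: the allocator setting only takes effect if it is exported before the first CUDA allocation, which is why `set_pytorch_cuda_alloc_conf()` runs early in the CLI entrypoint. A small sketch of checking the result, assuming the import path shown in the CLI diff above:

```python
# Sketch: the env var must be set before the first CUDA allocation to matter.
import os

import torch
from axolotl.utils import set_pytorch_cuda_alloc_conf

set_pytorch_cuda_alloc_conf()  # no-op if PYTORCH_CUDA_ALLOC_CONF is already set
print(torch.__version__)
print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
```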
@@ -15,12 +15,6 @@
 {%- endif %}
 {%- endif %}
 {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
-{#- Determine the real last index: use provided value or default to messages length - 1 #}
-{%- if real_last_index is defined and real_last_index is not none %}
-    {%- set ns.real_last_index = real_last_index %}
-{%- else %}
-    {%- set ns.real_last_index = messages|length - 1 %}
-{%- endif %}
 {%- for message in messages[::-1] %}
     {%- set index = (messages|length - 1) - loop.index0 %}
     {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
@@ -43,7 +37,7 @@
         {%- endif %}
     {%- endif %}
     {%- if loop.index0 > ns.last_query_index %}
-        {%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
+        {%- if loop.last or (not loop.last and reasoning_content) %}
             {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
         {%- else %}
             {{- '<|im_start|>' + message.role + '\n' + content }}
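The Jinja loop above scans `messages` in reverse to find the last plain user query, skipping tool responses wrapped in `<tool_response>` tags. A rough, simplified Python equivalent of that scan (toy messages; the template's `multi_step_tool` handling is omitted):

```python
def last_query_index(messages: list[dict]) -> int:
    """Index of the last user message that is not a wrapped tool response."""
    for index in range(len(messages) - 1, -1, -1):
        msg = messages[index]
        is_tool_response = msg["content"].startswith(
            "<tool_response>"
        ) and msg["content"].endswith("</tool_response>")
        if msg["role"] == "user" and not is_tool_response:
            return index
    return len(messages) - 1  # fallback mirrors the template's default


messages = [
    {"role": "user", "content": "What's the weather?"},
    {"role": "assistant", "content": "Let me check."},
    {"role": "user", "content": "<tool_response>sunny</tool_response>"},
    {"role": "assistant", "content": "It's sunny."},
]
print(last_query_index(messages))  # 0: the tool response at index 2 is skipped
```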
@@ -203,7 +203,6 @@ def wrap_streaming_dataset(
         max_seq_length=cfg.sequence_len,
         batch_size=cfg.micro_batch_size,
         multipack_attn=multipack_attn,
-        bin_size=cfg.sample_packing_bin_size,
     )
 
     # Set this to 1 so downstream data_loader doesn't try to increase the batch size
@@ -255,7 +254,6 @@ def encode_packed_streaming(
     collate_fn,
     ds_wrapper: Callable,
     examples: Dict[str, List],
-    bin_size: int,
     max_seq_length: int = 2048,
     batch_size: int = 4,
     multipack_attn: Optional[bool] = True,
@@ -280,7 +278,6 @@ def encode_packed_streaming(
         batch_max_len=batch_size * max_seq_length,
         drop_last=True,
         num_processes=1,
-        bin_size=bin_size,
     )
 
     chunked_data = defaultdict(list)
@@ -180,20 +180,15 @@
 def handle_long_seq_in_dataset(
     dataset: Dataset, sequence_len: int, cfg: DictDefault
 ) -> Dataset:
-    """
-    Remove or truncate sequences that exceed the configured maximum length from a dataset.
+    """Remove sequences longer than configured maximum from dataset.
 
-    Parameters:
-        dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged.
-        sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated.
-        cfg (DictDefault): Configuration object with keys:
-            - excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences.
-            - min_sample_len: minimum allowed sequence length (used when truncating or dropping).
-            - dataset_num_proc: number of processes to use for non-streaming datasets.
-            - is_preprocess: when true, bypasses cached preprocessing during filtering.
+    Args:
+        dataset: Dataset to filter.
+        sequence_len: Maximum length for sequences to keep
+        cfg: Dictionary mapping `axolotl` config keys to values.
 
     Returns:
-        Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`.
+        Filtered dataset with long sequences removed.
     """
     if (
         hasattr(dataset, "column_names")
@@ -211,13 +206,10 @@ def handle_long_seq_in_dataset(
         )
         return dataset
 
-    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
-
     drop_long = functools.partial(
         drop_long_seq,
         sequence_len=sequence_len,
         min_sequence_len=cfg.min_sample_len,
-        raise_on_drop=excess_length_strategy == "raise",
     )
 
     with contextlib.suppress(AttributeError):
@@ -238,6 +230,7 @@ def handle_long_seq_in_dataset(
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
+    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
     if excess_length_strategy == "truncate":
         process_fn = functools.partial(
             truncate_long_seq,
@@ -80,9 +80,6 @@ class HFMistralTokenizer(MistralCommonTokenizer):
     ) -> str | list[int]:
         """Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg"""
 
-        # pop unnecessary kwarg for mistral
-        kwargs.pop("real_last_index", None)
-
         try:
             if add_generation_prompt:
                 self._set_mode(ValidationMode.serving)
@@ -221,10 +218,3 @@ class HFMistralTokenizer(MistralCommonTokenizer):
             model_input_names=model_input_names,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
         )
-
-    def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]:
-        """
-        Patches to remove save_jinja_files from being passed onwards.
-        """
-        kwargs.pop("save_jinja_files", None)
-        return super().save_pretrained(*args, **kwargs)
@@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler):
         batch_size: int,  # Number of bins per batch
         batch_max_len: int,  # Maximum sequence length (bin capacity)
         lengths: np.ndarray,  # Sequence lengths
-        bin_size: int,  # The max number of samples that can be packed in a single bin
         packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate
         drop_last: bool = True,  # Whether to drop final batches (might be incomplete)
         num_count_samples: int = 4,  # Number of times to estimate batch count
         sequential: bool = False,  # Whether to use sequential packing
         group_size: int = 100_000,  # Size of groups for parallel packing
+        bin_size: int = 200,  # The max number of samples that can be packed in a single bin
         num_processes: int | None = None,  # Number of processes for parallel packing
         safe_mode: bool = True,  # Conservative packing to prevent training instability
         mp_start_method: str = "fork",
@@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler):
             lengths,
             bin_capacity=self.batch_max_len,
             group_size=self.group_size,
-            bin_size=self.bin_size or self.batch_max_len,
+            bin_size=self.bin_size,
             num_processes=min(4, num_processes) if num_processes else 4,
             safe_mode=self.safe_mode,
             mp_start_method=self.mp_start_method,
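For intuition about what `bin_size` bounds here: each bin holds at most `batch_max_len` tokens and at most `bin_size` samples. A generic first-fit-decreasing packing sketch under those two caps (illustrative only, not the sampler's exact parallel algorithm):

```python
# Generic first-fit-decreasing packing with a per-bin sample cap, for intuition.
def pack(lengths: list[int], bin_capacity: int, bin_size: int) -> list[list[int]]:
    bins: list[tuple[int, list[int]]] = []  # (used tokens, sample indices)
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    for idx in order:
        for n, (used, members) in enumerate(bins):
            # A sample fits if both the token budget and the sample cap allow it.
            if used + lengths[idx] <= bin_capacity and len(members) < bin_size:
                bins[n] = (used + lengths[idx], members + [idx])
                break
        else:
            bins.append((lengths[idx], [idx]))  # open a new bin
    return [members for _, members in bins]


print(pack([900, 700, 500, 300, 100], bin_capacity=1024, bin_size=200))
# [[0, 4], [1, 3], [2]]
```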
@@ -109,12 +109,6 @@ class LoraConfig(BaseModel):
             )
         },
     )
-    peft_autocast_adapter_dtype: bool | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT."
-        },
-    )
 
     qlora_sharded_model_loading: bool | None = Field(
         default=False,
@@ -201,33 +201,16 @@ def add_pose_position_ids(
 
 
 def add_length(sample):
-    """
-    Set the "length" field on a sample to the number of input tokens.
-
-    Parameters:
-        sample (Mapping-like): A sample containing an "input_ids" sequence.
-
-    Returns:
-        sample (dict-like): The same sample with "length" set to len(sample["input_ids"]).
-    """
     sample["length"] = len(sample["input_ids"])
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     """
-    Return whether a sample (single or batched) should be kept based on sequence length constraints.
+    Drop samples whose sequence length is either too long (> sequence_len)
+    or too short (< min_sequence_len).
 
-    Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError.
-
-    Parameters:
-        sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences.
-        sequence_len (int): Maximum allowed sequence length (inclusive).
-        min_sequence_len (int): Minimum allowed sequence length (inclusive).
-        raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len.
-
-    Returns:
-        bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept.
+    Works for both single-example (list[int]) or batched (list[list[int]]).
     """
     min_sequence_len = min_sequence_len or 2
 
@@ -242,20 +225,12 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=F
     if isinstance(input_ids[0], int):
         # Single example (input_ids is a list of int)
         length = len(input_ids)
-        if raise_on_drop and length > sequence_len:
-            raise ValueError(
-                f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
-            )
         return min_sequence_len <= length <= sequence_len
 
     # Batched (input_ids is a list of lists)
     results = []
    for seq in input_ids:
         length = len(seq)
-        if raise_on_drop and length > sequence_len:
-            raise ValueError(
-                f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
-            )
         results.append(min_sequence_len <= length <= sequence_len)
     return results
 
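A quick usage sketch of the keep/drop predicate as it reads after this change. The function is restated inline so the example is self-contained; the real implementation lives in axolotl's data utilities:

```python
# Self-contained restatement of the predicate semantics shown above.
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    input_ids = sample["input_ids"]
    if not input_ids:
        return False  # an empty sample is treated as dropped
    if isinstance(input_ids[0], int):  # single example
        return min_sequence_len <= len(input_ids) <= sequence_len
    return [  # batched: one keep/drop decision per sequence
        min_sequence_len <= len(seq) <= sequence_len for seq in input_ids
    ]


batch = {"input_ids": [[1] * 10, [1] * 3000, [1]]}
print(drop_long_seq(batch, sequence_len=2048))  # [True, False, False]
```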
@@ -1,81 +0,0 @@
-"""
-Test for KD chat template strategies
-"""
-
-from unittest.mock import Mock
-
-import pytest
-
-from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2
-
-
-class TestChatTemplateStrategyWithKDv2:
-    """Test v2 strategy correctly handles target_token_ids"""
-
-    @pytest.fixture
-    def v2_strategy(self):
-        """Create v2 strategy instance with mocked dependencies"""
-        # Mock prompter
-        mock_prompter = Mock()
-        mock_prompter.roles = {"user": "user", "assistant": "assistant"}
-        mock_prompter.chat_template_msg_variables = ["role", "content"]
-        mock_prompter.chat_template = "{{ messages }}"
-
-        # Mock tokenizer
-        mock_tokenizer = Mock()
-        mock_tokenizer.pad_token_id = 0
-        mock_tokenizer.eos_token_id = 2
-        mock_tokenizer.bos_token_id = 1
-        mock_tokenizer.eos_token = "<|endoftext|>"
-        mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2])
-        mock_tokenizer.encode = Mock(return_value=[2])
-
-        return ChatTemplateStrategyWithKDv2(
-            prompter=mock_prompter,
-            tokenizer=mock_tokenizer,
-            train_on_inputs=False,
-            sequence_len=512,
-            logprobs_field="logprobs",
-            gen_temperature=1.0,
-            kd_temperature=1.0,
-        )
-
-    def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):
-        """
-        Test that v2's _prepare_kd_fields hook adds target_token_ids.
-
-        Validates the Template Method pattern fix where v2 overrides
-        the hook to add target_token_ids before transform.
-        """
-        tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
-        original = {"target_token_ids": [[10, 20], [30, 40]]}
-
-        result = v2_strategy._prepare_kd_fields(tokenized, original)
-
-        assert "target_token_ids" in result
-        assert result["target_token_ids"] == [[10, 20], [30, 40]]
-
-    def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):
-        """Test hook handles missing target_token_ids gracefully"""
-        tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
-        original = {}
-
-        result = v2_strategy._prepare_kd_fields(tokenized, original)
-
-        assert "target_token_ids" not in result
-
-    def test_v2_transform_requires_target_token_ids(self, v2_strategy):
-        """
-        Test v2's transform fails without target_token_ids.
-
-        Validates the bug fix - transform expects target_token_ids
-        to be added by the hook.
-        """
-        sample = {
-            "input_ids": [1, 10, 20, 30, 2],
-            "labels": [1, 10, 20, 30, 2],
-            "logprobs": [[-0.1, -0.2], [-0.3, -0.4]],
-        }
-
-        with pytest.raises(KeyError, match="target_token_ids"):
-            v2_strategy.transform_logprobs(sample)