Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
21ba1cd3f1 wire up squash_position_ids 2025-08-23 16:21:28 -04:00
357 changed files with 12414 additions and 16119 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101,B615,B102,B110
skips = B101,B615

View File

@@ -12,6 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: false
auto_incremental_review: true
chat:
auto_reply: true

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -36,11 +36,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -115,11 +110,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -179,12 +169,6 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -36,15 +36,15 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -240,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -298,13 +298,6 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras: fbgemm-gpu
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -325,7 +318,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
@@ -342,10 +334,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,12 +10,22 @@ repos:
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: ruff
args: [--fix, --select, I]
- id: ruff-format
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.3.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.8
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
hooks:

15
.pylintrc Normal file
View File

@@ -0,0 +1,15 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -1,6 +1,6 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Open Source LLM Post-Training"
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"

View File

@@ -5,9 +5,6 @@
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<p align="center">
<strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
</p>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
@@ -20,7 +17,6 @@
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -53,21 +49,20 @@
## ✨ Overview
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Axolotl is a tool designed to streamline post-training for various AI models.
Features:
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
## 🚀 Quick Start - LLM Fine-tuning in Minutes
## 🚀 Quick Start
**Requirements**:
@@ -75,10 +70,6 @@ Features:
- Python 3.11
- PyTorch ≥2.6.0
### Google Colab
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
### Installation
#### Using pip
@@ -164,7 +155,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Open Source LLM Post-Training},
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},

View File

@@ -153,7 +153,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.streaming
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -272,7 +272,6 @@ website:
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd

View File

@@ -2,6 +2,8 @@
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -61,7 +63,7 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code)
exit(exit_code) # pylint: disable=consider-using-sys-exit
@app.function(

View File

@@ -1,5 +1,7 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -57,8 +59,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
@@ -69,4 +70,4 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code)
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -12,7 +12,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null
@@ -27,7 +27,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null

View File

@@ -134,7 +134,7 @@ For providers supporting Docker:
### Google Colab {#sec-colab}
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}

View File

@@ -63,6 +63,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
::: {.callout-tip}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}

View File

@@ -51,11 +51,3 @@ axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
::: {.callout-note}
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
:::

View File

@@ -11,7 +11,6 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b

View File

@@ -47,6 +47,7 @@ class QuartoGenerator:
"""Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+)
@@ -123,6 +124,7 @@ class QuartoGenerator:
return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type(
self, field_type
) -> list[type[BaseModel]]:
@@ -316,6 +318,7 @@ class QuartoGenerator:
return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source(
self, model_class: type[BaseModel]
) -> list[dict]:
@@ -500,7 +503,7 @@ class QuartoGenerator:
nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", [])
except Exception:
except Exception: # pylint: disable=broad-exception-caught
# Fallback: use model fields directly
nested_properties = {}
nested_required = []
@@ -604,7 +607,7 @@ class QuartoGenerator:
schema = model_class.model_json_schema()
properties = schema.get("properties", {})
required = schema.get("required", [])
except Exception as e:
except Exception as e: # pylint: disable=broad-exception-caught
print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
)

View File

@@ -1,120 +0,0 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- You can control the memory usage by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -1,10 +0,0 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

File diff suppressed because it is too large Load Diff

View File

@@ -20,13 +20,7 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml

View File

@@ -1,68 +0,0 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
# huggingface repo
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -106,16 +106,6 @@ See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-to
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### Thinking and chat_template masking conflict
OpenAIs Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotls `chat_template` masking.
If your dataset has `thinking` content mid-turn, there are two paths we recommend:
- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message).
- Adjust your dataset to only have `thinking` content in the last turn.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -1,85 +0,0 @@
# Finetune HunYuan with Axolotl
Tencent released a family of opensource models called HunYuan with varying parameter scales of 0.5B, 1.8B, 4B, and 7B scale for both Pre-trained and Instruct variants. The models can be found at [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as HunYuan is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml
```
This config uses about 4.7 GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Dataset
HunYuan Instruct models can choose to enter a slow think or fast think pattern. For best performance on fine-tuning their Instruct models, your dataset should be adjusted to match their pattern.
```python
# fast think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\n\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
# slow think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\nThe user is asking about the color of the sun. I need to ...\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
```
### TIPS
- For inference, the official Tencent team recommends
```json
{
"do_sample": true,
"top_k": 20,
"top_p": 0.8,
"repetition_penalty": 1.05,
"temperature": 0.7
}
```
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: tencent/Hunyuan-0.5B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,64 +0,0 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
gradient_accumulation_steps: 1
micro_batch_size: 64
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,18 +15,20 @@ liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
sample_packing: false
sequence_len: 8192
flash_attention: true
sample_packing: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
@@ -65,7 +67,7 @@ fsdp:
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
@@ -74,6 +76,6 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,56 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
diffusion:
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
num_diffusion_steps: 128
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
warmup_ratio: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
sdp_attention: true
bf16: auto
tf32: true
logging_steps: 1
save_strategy: steps
save_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,59 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.05
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
diffusion:
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
warmup_steps: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
sdp_attention: true
logging_steps: 1
save_strategy: best
eval_strategy: epoch
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -18,13 +18,7 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml

View File

@@ -1,44 +0,0 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -1,54 +0,0 @@
# Finetune ByteDance's Seed-OSS with Axolotl
[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) are a series of 36B parameter open source models trained by ByteDance's Seed Team.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml
```
This config uses about 27.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [ByteDance Seed Website](https://seed.bytedance.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,56 +0,0 @@
base_model: ByteDance-Seed/Seed-OSS-36B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,50 +0,0 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimize per-step padding tokens
## Performance Tips
- Download small / frequently-used datasets locally for better performance
- Larger buffer sizes improve packing efficiency

View File

@@ -1,57 +0,0 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
name: sample-10BT
type: pretrain
text_column: text
split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,55 +0,0 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -22,9 +22,6 @@ pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# audio
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:

View File

@@ -26,34 +26,3 @@ include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long
"C901", # Too complex
"B019", # Use of functools.cache on methods
"E722", # Bare except
"F821", # Undefined name (for dynamic exec)
]
[tool.ruff.lint.isort]
known-third-party = ["wandb", "comet_ml"]
known-local-folder = ["src", "tests"]
# Black-compatible isort settings
force-single-line = false
combine-as-imports = true
split-on-trailing-comma = true
[tool.ruff.format]
# Use black's formatting style exactly
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false

View File

@@ -2,7 +2,8 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
triton>=3.0.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -13,7 +14,7 @@ packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
transformers==4.55.3
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0
@@ -64,7 +65,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6

View File

@@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not field_messages:
raise ValueError(
f"No conversation field found in dataset: {', '.join(feature_keys)}"
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
@@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["role"]:
raise ValueError(
f"No role field found in messages: {', '.join(message_fields)}"
f'No role field found in messages: {", ".join(message_fields)}'
)
for key in ["content", "text", "value"]:
@@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["content"]:
raise ValueError(
f"No content field found in messages: {', '.join(message_fields)}"
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_property_mappings"] = message_property_mappings

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
)

View File

@@ -1,10 +1,11 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]

View File

@@ -64,9 +64,7 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
@@ -127,7 +125,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.5",
"deepspeed==0.17.2",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -162,7 +160,6 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require

View File

@@ -14,13 +14,9 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=False,
default=None,
metadata={
"help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@@ -115,7 +111,6 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass

View File

@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process():

View File

@@ -7,8 +7,6 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -40,15 +38,8 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}

View File

@@ -1,48 +0,0 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -1,9 +0,0 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -1,71 +0,0 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code)
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
@@ -130,6 +130,7 @@ class ModalCloud(Cloud):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
@@ -176,8 +177,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
*args,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
@@ -186,7 +187,7 @@ class ModalCloud(Cloud):
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self):
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
@@ -276,7 +277,7 @@ def _train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs,
**kwargs, # pylint: disable=unused-argument
):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:

View File

@@ -210,7 +210,7 @@ def load_cfg(
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except:
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)

View File

@@ -28,7 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -49,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -14,14 +14,10 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.diffusion import (
diffusion_inference,
launch_diffusion_gradio_ui,
render_html,
run_diffusion,
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -36,11 +32,10 @@ def get_multi_line_input() -> str:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
print("=" * 80)
instruction = ""
for line in sys.stdin:
instruction += line
instruction += line # pylint: disable=consider-using-join
return instruction
@@ -51,9 +46,9 @@ def do_inference(
cli_args: InferenceCliArgs,
):
"""
Runs inference on the command line in a loop. User input is accepted, a chat
template is (optionally) applied, and the model specified in the `axolotl` config is
used to generate completions according to a default generation config.
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -69,31 +64,17 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
print("=" * 80)
print("Commands:")
print(":complete N -> completion mode with N tokens (default 64)")
print(":mask R -> random masking with ratio R (0.01.0)")
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
@@ -123,19 +104,9 @@ def do_inference(
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 80)
print("=" * 40)
model.eval()
with torch.no_grad():
if is_diffusion:
diffusion_inference(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
)
continue
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
@@ -158,7 +129,7 @@ def do_inference(
generation_config=generation_config,
streamer=streamer,
)
print("=" * 80)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -188,37 +159,15 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
launch_diffusion_gradio_ui(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompter_module=prompter_module,
chat_template_str=chat_template_str,
)
return
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
@@ -303,7 +252,7 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)

View File

@@ -1,5 +1,7 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import os
import subprocess # nosec B404
from typing import Literal, Optional

View File

@@ -43,10 +43,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor):
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))
@@ -59,10 +59,10 @@ def _distributed_checkpoint_to_merged_weights(
state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict(
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
no_dist=True,
)
@@ -191,7 +191,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -35,20 +35,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
)
return
@@ -83,7 +73,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
except Exception: # nosec B110
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
pass
# fmt: on
@@ -105,7 +95,7 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)

View File

@@ -5,17 +5,12 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from transformers import AutoModelForCausalLM
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = get_logger(__name__)
@@ -48,13 +43,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("base_model") or cfg.output_dir
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -62,15 +57,10 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}.")
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
@@ -80,21 +70,11 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model(
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
@@ -104,16 +84,5 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -59,7 +59,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -65,7 +65,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
for field in reversed(dataclasses.fields(config_class)):
field_type = _strip_optional_type(field.type)
if field_type is bool:
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -103,7 +103,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
if field_type is bool:
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(

View File

@@ -1,375 +0,0 @@
"""Helpers for diffusion-mode inference in CLI and Gradio."""
from __future__ import annotations
import gradio as gr
import torch
from colorama import Fore, Style
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
from axolotl.utils.dict import DictDefault
def diffusion_inference(
model,
tokenizer,
cfg,
prompt: str,
chat_template_str: str | None = None,
):
"""Diffusion inference helper method."""
mode = "random"
completion_tokens = 0
target_mask_ratio = None
mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)
if cleaned:
prompt = cleaned
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=mode,
target_mask_ratio=target_mask_ratio,
completion_tokens=completion_tokens,
)
masked_text = info["masked_text"]
mask_ratio = info["mask_ratio"]
generated_ids = info["generated_ids"]
masked_positions = info["masked_positions"]
orig_ids = info["orig_ids"]
# Display with masked preview and colored diff
if masked_text is not None and mask_ratio is not None:
print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n")
if generated_ids is not None:
# Compute per-token style
styles: list[str] = []
for i, tid in enumerate(generated_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green") # correct fill
elif i < len(orig_ids):
styles.append("red") # incorrect fill
else:
styles.append("normal") # appended
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
# Group contiguous spans by style
styled_spans: list[tuple[str, int, int]] = []
if generated_ids:
current_style = styles[0]
start = 0
for i in range(1, len(generated_ids)):
s = styles[i]
if s != current_style:
styled_spans.append((current_style, start, i))
current_style, start = s, i
styled_spans.append((current_style, start, len(generated_ids)))
out_parts = []
for style_name, a, b in styled_spans:
chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
out_parts.append(chunk_text)
print("Generated:\n" + "".join(out_parts))
else:
print("Generated:\n(no output)")
def _parse_commands(text: str):
"""
Parse leading diffusion commands.
Supported at start of input (can be chained):
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R in [0, 1]
"""
tokens = text.strip().split()
i = 0
mode = "random"
completion_tokens = 0
target_mask_ratio = None
consumed = 0
while i < len(tokens) and tokens[i].startswith(":"):
cmd = tokens[i]
i += 1
consumed = i
if cmd == ":complete":
mode = "completion"
if i < len(tokens):
try:
completion_tokens = int(tokens[i])
i += 1
consumed = i
except Exception:
completion_tokens = 64
else:
completion_tokens = 64
elif cmd == ":mask":
mode = "random"
if i < len(tokens):
try:
target_mask_ratio = float(tokens[i])
i += 1
consumed = i
except Exception:
target_mask_ratio = None
else:
i -= 1
consumed = i
break
cleaned = " ".join(tokens[consumed:])
return mode, completion_tokens, target_mask_ratio, cleaned
def run_diffusion(
*,
model,
tokenizer,
cfg: DictDefault,
prompt: str,
chat_template_str: str | None,
mode: str = "random",
target_mask_ratio: float | None = None,
completion_tokens: int = 0,
):
"""Run a single diffusion generation and return a structured result dict."""
if chat_template_str:
batch = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
seq = batch["input_ids"].to(cfg.device)
gen_mode = "completion" if mode == "completion" else "random"
comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0
result = generate(
model,
tokenizer,
original_sequence=seq[:1],
num_diffusion_steps=cfg.diffusion.num_diffusion_steps,
temperature=cfg.diffusion.generation_temperature,
mask_token_id=int(mask_token_id),
mode=gen_mode, # type: ignore[arg-type]
completion_tokens=comp_tokens,
target_mask_ratio=target_mask_ratio,
)
masked_text = result.get("masked") if isinstance(result, dict) else None
mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None
generated_ids = result.get("generated_ids") if isinstance(result, dict) else None
masked_positions = (
set(result.get("masked_positions") or []) if isinstance(result, dict) else set()
)
orig_ids = seq[0].detach().cpu().tolist()
return {
"masked_text": masked_text,
"mask_ratio": mask_ratio,
"generated_ids": generated_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids,
}
def render_html(
*,
generated_ids: list[int] | None,
orig_ids: list[int],
masked_positions: set[int],
tokenizer,
) -> str:
"""Render HTML visualizing diffusion outputs."""
if not generated_ids:
return "<pre>Generated:\n(no output)</pre>"
def _style_for(i: int, tid: int) -> str:
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
return "green"
if i < len(orig_ids):
return "red"
return "normal"
same = i < len(orig_ids) and tid == orig_ids[i]
return "dim" if same else "normal"
# Group contiguous spans by style to reduce HTML size
spans: list[tuple[str, int, int]] = []
if generated_ids:
cur = _style_for(0, generated_ids[0])
start = 0
for i in range(1, len(generated_ids)):
s = _style_for(i, generated_ids[i])
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(generated_ids)))
html_parts = []
for style_name, a, b in spans:
txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
html_parts.append(f'<span style="color:#2e7d32">{txt}</span>')
elif style_name == "red":
html_parts.append(f'<span style="color:#c62828">{txt}</span>')
elif style_name == "dim":
html_parts.append(f'<span style="opacity:0.6">{txt}</span>')
else:
html_parts.append(txt)
legend = (
'<div style="font-size:0.9em;margin-bottom:4px">'
'<span style="color:#2e7d32">correct</span>, '
'<span style="color:#c62828">incorrect</span>, '
'<span style="opacity:0.6">unchanged</span>'
"</div>"
)
return (
legend
+ '<pre style="white-space:pre-wrap">Generated:\n'
+ "".join(html_parts)
+ "</pre>"
)
def launch_diffusion_gradio_ui(
*,
model,
tokenizer,
cfg: DictDefault,
prompter_module=None,
chat_template_str: str | None = None,
):
"""Build and launch a simple Gradio UI for diffusion inference."""
with gr.Blocks(
title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
) as demo:
gr.Markdown(
"""
## Axolotl Diffusion Inference
- Mode "Random" masks tokens at a target ratio and fills them.
- Mode "Completion" appends N masked tokens at the end and fills them.
"""
)
with gr.Row():
mode = gr.Radio(
choices=["random", "completion"],
value="random",
label="Mode",
)
mask_ratio = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.4,
label="Mask ratio (random mode)",
interactive=True,
)
completion_tokens = gr.Number(
value=64,
precision=0,
label="Completion tokens (completion mode)",
interactive=True,
visible=False,
)
instruction = gr.Textbox(label="Instruction", lines=6)
run_btn = gr.Button("Generate")
masked_preview = gr.Textbox(label="Masked preview", lines=6)
html_out = gr.HTML(label="Generated")
def _toggle_controls(selected_mode: str):
return (
gr.update(visible=(selected_mode == "random")),
gr.update(visible=(selected_mode == "completion")),
)
mode.change(
_toggle_controls,
inputs=[mode],
outputs=[mask_ratio, completion_tokens],
)
def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
if not instruction_text:
return "", "<pre>Generated:\n(no output)</pre>"
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(
instruction=instruction_text.strip("\n")
)
)
else:
prompt = instruction_text.strip()
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=selected_mode,
target_mask_ratio=mratio if selected_mode == "random" else None,
completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
)
masked_text = info.get("masked_text")
mask_ratio_val = info.get("mask_ratio")
generated_ids = info.get("generated_ids")
masked_positions = info.get("masked_positions") or set()
orig_ids = info.get("orig_ids") or []
preview = (
f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
if masked_text is not None and mask_ratio_val is not None
else ""
)
html = render_html(
generated_ids=generated_ids,
orig_ids=orig_ids,
masked_positions=masked_positions,
tokenizer=tokenizer,
)
return preview, html
run_btn.click(
_gen,
inputs=[instruction, mode, mask_ratio, completion_tokens],
outputs=[masked_preview, html_out],
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)

View File

@@ -49,10 +49,7 @@ def generate_sweep_configs(
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {
**dict(zip(param_names, reg_combo, strict=False)),
**paired_set,
}
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
@@ -61,7 +58,7 @@ def generate_sweep_configs(
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo, strict=False):
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)

View File

@@ -95,6 +95,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",
suffix=".yaml",

View File

@@ -39,7 +39,7 @@ def do_vllm_serve(
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = 1
data_parallel_size = 1
@@ -68,6 +68,7 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
@@ -55,11 +54,13 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -67,7 +67,9 @@ class JsonToJsonlConverter:
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(self, input_file_path, output_file_path):
def convert(
self, input_file_path, output_file_path
): # pylint: disable=unused-argument
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -84,7 +84,9 @@ def create_causal_mask(
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
def causal_doc_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
@@ -101,7 +103,9 @@ def create_causal_mask(
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
config._attn_implementation # pylint: disable=protected-access
]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (

View File

@@ -24,7 +24,9 @@ from pathlib import Path
from typing import Any
import torch
from transformers import TrainerCallback
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -42,7 +44,7 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo
import torch._dynamo # pylint: disable=ungrouped-imports
class TrainerBuilderBase(abc.ABC):
@@ -258,14 +260,14 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import (
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import (
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
@@ -412,8 +414,12 @@ class TrainerBuilderBase(abc.ABC):
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.accumulated_cache_size_limit = 256
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
@@ -510,7 +516,6 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -10,7 +10,6 @@ import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
@@ -36,7 +35,6 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -76,12 +74,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -348,22 +340,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
if self.cfg.center_rewards_coefficient is not None:
training_arguments_kwargs["center_rewards_coefficient"] = (
self.cfg.center_rewards_coefficient
)
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls(
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = None
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
@@ -395,11 +385,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
if "processing_class" in sig.parameters:
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
@@ -417,9 +406,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
@@ -490,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
if self.cfg.squash_position_ids:
kwargs["squash_position_ids"] = True
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:

View File

@@ -168,14 +168,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls(
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = None
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args, trainer_kwargs

View File

@@ -10,7 +10,7 @@ from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message

View File

@@ -15,11 +15,11 @@ class MessageRoles(str, Enum):
Message roles for the system, user, assistant, and tools
"""
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
ipython = (
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)
@@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum):
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # nosec B105
text = "text"
image = "image"
audio = "audio"
tool_call = "tool_call"
tool_response = "tool_response"
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
class SpecialToken(str, Enum):
@@ -43,8 +43,8 @@ class SpecialToken(str, Enum):
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # nosec B105
eos_token = "eos_token" # nosec B105
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
class ToolCallFunction(BaseModel):
@@ -73,7 +73,7 @@ class ToolCallContents(BaseModel):
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
@@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel):
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}

View File

@@ -1,17 +1,23 @@
"""
This module contains a function that builds a transform that takes a row from the
dataset and converts it to a Chat.
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping
from typing import Any, Mapping, Union
def chat_message_transform_builder(
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: str | list[str] | None = None, # commonly "content"
message_field_training: str | list[str] | None = None, # commonly "weight"
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
@@ -33,12 +39,6 @@ def chat_message_transform_builder(
A function that takes a list of conversations and returns a list of messages.
"""
if message_field_training is None:
message_field_training = ["train", "weight"]
if message_field_content is None:
message_field_content = ["value", "text", "content"]
if message_field_role is None:
message_field_role = ["role", "from"]
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)

View File

@@ -1,5 +1,6 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer

View File

@@ -1,5 +1,7 @@
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
import os
@@ -42,20 +44,12 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
"max": torch.max,
"sum": torch.sum,
}
class AxolotlTrainer(
PackingMixin,
@@ -71,15 +65,6 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -95,10 +80,9 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -301,9 +285,9 @@ class AxolotlTrainer(
# fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
else:
self._eval_dataloaders = {dataloader_key: dataloader}
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
# fmt: on
return self.accelerator.prepare(dataloader)
@@ -345,17 +329,6 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -371,11 +344,6 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
@override
def evaluate(self, *args, **kwargs):
LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}
@@ -475,7 +443,7 @@ class AxolotlTrainer(
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
@@ -556,10 +524,15 @@ class AxolotlTrainer(
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state(reset_partial_state=True)
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -567,6 +540,7 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:
@@ -599,61 +573,29 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
reduction_type = metric_data["reduction"]
fn = REDUCTION_FNS.get(reduction_type)
if fn is None:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(fn(values).item(), 4)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
if is_main_process():
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
def store_metrics(
self,
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
"""
for key, value in metrics.items():
if isinstance(value, tuple):
value, _reduction = value # type: ignore[assignment]
else:
value, _reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(value)
self._stored_metrics[train_eval][key]["reduction"] = _reduction
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
@@ -720,11 +662,6 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -101,11 +101,11 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type]
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo"
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -128,7 +128,9 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
def set_trainer_args(
cls, cfg: DictDefault
) -> list[Any]: # pylint: disable=unused-argument
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
@@ -149,7 +151,7 @@ class GRPOStrategy:
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs):
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None

View File

@@ -1,5 +1,7 @@
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
@@ -50,6 +52,7 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
@@ -250,7 +253,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
@@ -263,7 +266,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns(
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
@@ -305,10 +308,10 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
# pylint: disable=access-member-before-definition
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
self._move_model_to_vllm()
# pylint: disable=attribute-defined-outside-init
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
@@ -330,9 +333,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
group_prompts = all_prompts_text[
group_leader_rank * len(prompts_text) : (
group_leader_rank + 1
)
group_leader_rank
* len(prompts_text) : (group_leader_rank + 1)
* len(prompts_text) : self.num_generations
]
@@ -483,7 +485,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text, strict=False):
for prompt, completion in zip(prompts, completions_text):
bootstrap = (
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
)
@@ -501,7 +503,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.reward_funcs,
self.reward_processing_classes,
self.reward_func_names,
strict=False,
)
):
with profiling_context(self, reward_func_name):
@@ -510,17 +511,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [
{"messages": p + c}
for p, c in zip(prompts, completions, strict=False)
{"messages": p + c} for p, c in zip(prompts, completions)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
else:
texts = [
p + c for p, c in zip(prompts, completions, strict=False)
]
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
text=texts,
return_tensors="pt",
@@ -566,8 +564,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward.",
stacklevel=2,
"Please ensure that at least one reward function returns a valid reward."
)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the

View File

@@ -5,6 +5,7 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
@@ -14,8 +15,8 @@ class AxolotlMambaTrainer(AxolotlTrainer):
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits

View File

@@ -1,5 +1,6 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin

View File

@@ -92,7 +92,7 @@ def get_lora_act_offloading_ctx_manager(
`contextlib.ContextDecorator`:
Activation offloading context manager for the model.
"""
# pylint: disable=unnecessary-dunder-call
activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory,
use_streams=use_streams,

View File

@@ -26,6 +26,7 @@ class DistributedParallelMixin(Trainer):
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"

View File

@@ -70,11 +70,11 @@ class OptimizerMixin(Trainer):
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"]
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
@@ -143,7 +143,7 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer(
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
@@ -185,15 +185,17 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped / 2**20}M params")
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped / 2**20}M params")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer

View File

@@ -46,7 +46,7 @@ class SchedulerMixin(Trainer):
)
# fmt: off
if self.lr_scheduler is None: # type: ignore
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
@@ -90,7 +90,7 @@ class SchedulerMixin(Trainer):
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -98,7 +98,7 @@ class SchedulerMixin(Trainer):
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -107,7 +107,7 @@ class SchedulerMixin(Trainer):
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr(
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -133,7 +133,7 @@ class SchedulerMixin(Trainer):
)
if not self.lr_scheduler:
super().create_scheduler(num_training_steps, optimizer)
self.lr_scheduler = JaggedLRRestartScheduler(
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,

View File

@@ -14,6 +14,7 @@ class AxolotlTrainingMixins:
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
@@ -49,12 +50,6 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -1,17 +1,18 @@
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
"""Module containing Dataset functionality"""
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -25,7 +26,7 @@ class TokenizedPromptDataset(Dataset):
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
def __init__(
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
@@ -85,3 +86,133 @@ def wrap_dataset_for_tokenized_prompt(
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -79,7 +79,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps

View File

@@ -76,7 +76,7 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: dict):
def register(self, cfg: dict): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration as an unparsed dict.
Args:
@@ -104,13 +104,14 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg: DictDefault):
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions before the model is loaded.
Args:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied.
@@ -118,6 +119,7 @@ class BasePlugin:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded.
@@ -126,6 +128,7 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded.
@@ -134,6 +137,7 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded.
@@ -142,7 +146,8 @@ class BasePlugin:
model: The loaded model.
"""
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
Args:
@@ -152,6 +157,7 @@ class BasePlugin:
The first non-`None` trainer class returned by a plugin.
"""
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created.
@@ -160,7 +166,7 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault):
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
"""
Returns custom training arguments to set on TrainingArgs.
@@ -171,7 +177,9 @@ class BasePlugin:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False):
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
@@ -183,6 +191,7 @@ class BasePlugin:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -194,6 +203,7 @@ class BasePlugin:
The created optimizer.
"""
# pylint: disable=unused-argument
def create_lr_scheduler(
self,
cfg: DictDefault,
@@ -213,6 +223,7 @@ class BasePlugin:
The created learning rate scheduler.
"""
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
@@ -227,6 +238,7 @@ class BasePlugin:
"""
return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
@@ -242,6 +254,7 @@ class BasePlugin:
"""
return []
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
@@ -250,7 +263,7 @@ class BasePlugin:
model: The loaded model.
"""
def post_train_unload(self, cfg: DictDefault):
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions after training is complete and the model is unloaded.
Args:
@@ -298,7 +311,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager:
class PluginManager: # pylint: disable=too-many-public-methods
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.

View File

@@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
def merge_input_args():
@@ -50,9 +50,15 @@ def merge_input_args():
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
exec(dynamic_input, globals(), namespace) # nosec B102
AxolotlInputConfig = namespace["AxolotlInputConfig"]
AxolotlConfigWCapabilities = namespace["AxolotlConfigWCapabilities"]
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, globals(), namespace
)
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
"AxolotlInputConfig"
]
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
@@ -68,7 +74,7 @@ def merge_training_args() -> Type:
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
@@ -87,7 +93,11 @@ def merge_training_args() -> Type:
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec(dynamic_input, {**globals(), **local_vars}, namespace) # nosec B102
AxolotlTrainingMixins = namespace["AxolotlTrainingMixins"]
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
```
## Usage
@@ -34,7 +34,6 @@ plugins:
- arcee
- cohere
- cohere2
- deepseek_v3
- gemma
- gemma2
- gemma3
@@ -43,7 +42,6 @@ plugins:
- gemma3n_text
- glm
- glm4
- glm4_moe
- gpt_oss
- granite
- granitemoe
@@ -66,7 +64,6 @@ plugins:
- qwen3
- qwen3_moe
- smollm3
- seed_oss
- voxtral
## Citation

View File

@@ -18,7 +18,6 @@ Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
from functools import partial
@@ -29,13 +28,13 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs as CutCrossEntropyArgs
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
)
@@ -107,7 +106,9 @@ class CutCrossEntropyPlugin(BasePlugin):
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(maybe_model, patch_options, model_type: str):
def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -120,10 +121,12 @@ class CutCrossEntropyPlugin(BasePlugin):
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)
model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "

View File

@@ -15,7 +15,6 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator

View File

@@ -1,154 +0,0 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training using an approach inspired by
LLaDA (Large Language Diffusion Models) within Axolotl.
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to attend to the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. See our
[installation docs](https://docs.axolotl.ai/docs/installation.html).
## Quickstart
Train with an example config (Llama3.2 1B):
- Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`
- SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`
### Basic Configuration
You can also modify your existing configs to enable / customize diffusion training.
Add the following to your Axolotl config:
```yaml
# Enable diffusion LM training plugin
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
```
And, configure the nested `diffusion` block (defaults shown):
```yaml
diffusion:
noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
# Mask token (training auto-adds if missing, avoid pad/eos)
mask_token_str: "<|diffusion_mask|>"
# Or use an existing special token id (e.g., 128002 for Llama-3.x)
# mask_token_id: 128002
# Sample generation during training (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
```
## Supported Models
Any models that support 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a
[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!
## How It Works
### Random Masking
During training, tokens are randomly masked:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
### Diffusion Loss
Loss is computed only on masked tokens with (optional) importance weighting:
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
## Sample Generation
When `diffusion.generate_samples: true`, the plugin generates samples during training:
```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to console and wandb (if enabled).
## Inference
Diffusion inference is integrated into the standard Axolotl CLI. Use the same config
you trained with and run:
```
axolotl inference path/to/your-config.yaml
```
Optionally, pass `--gradio` to use a simple web interface.
Interactive controls (prefix the prompt with commands):
- `:complete N` → completion mode with N new masked tokens appended (default 64)
- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]
Example session:
```
================================================================================
Commands:
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R (0.01.0)
================================================================================
Give me an instruction (Ctrl + D to submit):
:mask 0.4 The quick brown fox jumps over the lazy dog
Masked (40.0%):
The [MASK] brown [MASK] jumps over the [MASK] dog
Generated:
The quick brown fox jumps over the loud dog
```
## Metrics and Monitoring
The plugin adds (or modifies) several metrics to track diffusion training:
- `train/loss`: Weighted diffusion loss
- `train/accuracy`: Accuracy on masked tokens
- `train/mask_ratio`: Average fraction of tokens masked
- `train/num_masked_tokens`: Number of tokens masked
- `train/avg_p_mask`: Average masking probability
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Limitations
- No flash attention support
- No RL training support
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)
- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)

View File

@@ -1,19 +0,0 @@
"""Diffusion LM training plugin init."""
from .args import DiffusionArgs, DiffusionConfig
from .callbacks import DiffusionGenerationCallback
from .generation import generate
from .plugin import DiffusionPlugin
from .trainer import DiffusionTrainer
from .utils import create_bidirectional_attention_mask, resolve_mask_token_id
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionTrainer",
"generate",
"resolve_mask_token_id",
"create_bidirectional_attention_mask",
"DiffusionGenerationCallback",
"DiffusionConfig",
]

View File

@@ -1,95 +0,0 @@
"""Config args for diffusion LM training (nested under `diffusion:`)."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field, model_validator
class DiffusionConfig(BaseModel):
"""Nested diffusion configuration available under the `diffusion` key."""
# Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training config
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)
mask_token_id: int | None = Field(
default=None,
description=(
"Token ID to use for masking. Unset by default; can use one of the "
"tokenizer's special tokens here."
),
)
mask_token_str: str | None = Field(
default=None,
description=(
"Token string to use as a mask. If `mask_token_id` is invalid or unset, "
"this token will be ensured to exist as an additional special token and "
"used. If absent, a default '<|diffusion_mask|>' will be added."
),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
)
@model_validator(mode="after")
def _validate_mask_ratios(self) -> "DiffusionConfig":
if self.min_mask_ratio > self.max_mask_ratio:
raise ValueError("min_mask_ratio must be ≤ max_mask_ratio")
return self
class DiffusionArgs(BaseModel):
"""Plugin entry that exposes the nested `diffusion` block to the core config."""
diffusion: DiffusionConfig = Field(
default_factory=DiffusionConfig,
description="Diffusion training configuration. Only nested block is supported.",
)

View File

@@ -1,174 +0,0 @@
"""Callbacks for diffusion training."""
import logging
import sys
import wandb
from colorama import Fore, Style
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from .generation import generate_samples
# Simpler logger for more readable sample generation
logger = logging.getLogger(__name__)
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)
class DiffusionGenerationCallback(TrainerCallback):
"""Callback for generating samples during diffusion training."""
def __init__(self, trainer):
self.trainer = trainer
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
if (
state.global_step > 0
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
):
if not self.trainer.state.is_world_process_zero:
return
# Use eval dataloader if available, otherwise use train dataloader
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
except Exception:
dataloader = None
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
# Generate samples
diffusion_cfg = self.trainer.cfg.diffusion
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=diffusion_cfg.num_generation_samples,
max_length=diffusion_cfg.generation_max_length,
num_diffusion_steps=diffusion_cfg.generation_steps,
temperature=diffusion_cfg.generation_temperature,
mask_token_id=diffusion_cfg.mask_token_id,
)
# Log samples
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples."""
if not samples:
return
logger.info("=" * 60)
logger.info("GENERATED SAMPLES")
logger.info("=" * 60)
for i, sample_data in enumerate(samples, 1):
original = sample_data["original"]
masked = sample_data["masked"]
generated = sample_data["generated"]
mask_ratio = sample_data["mask_ratio"]
masked_tokens = sample_data["masked_tokens"]
total_tokens = sample_data["total_tokens"]
logger.info(f"\nSample {i}:")
logger.info(f"\tOriginal ({total_tokens} tokens): {original}")
logger.info(
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
f"{mask_ratio:.1%}): {masked}"
)
try:
gen_ids = sample_data.get("generated_ids")
orig_ids = sample_data.get("orig_ids")
masked_positions = set(sample_data.get("masked_positions") or [])
if isinstance(gen_ids, list) and isinstance(orig_ids, list):
styles: list[str] = []
for i, tid in enumerate(gen_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green")
elif i < len(orig_ids):
styles.append("red")
else:
styles.append("normal")
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
spans: list[tuple[str, int, int]] = []
if gen_ids:
cur = styles[0]
start = 0
for i in range(1, len(gen_ids)):
s = styles[i]
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(gen_ids)))
parts = []
for style_name, a, b in spans:
chunk_text = self.trainer.processing_class.decode(
gen_ids[a:b], skip_special_tokens=False
)
if style_name == "green":
parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
parts.append(chunk_text)
logger.info("\tGenerated:\n%s", "".join(parts))
else:
logger.info(f"\tGenerated: {generated}")
except Exception:
logger.info(f"\tGenerated: {generated}")
logger.info("=" * 60)
if self.trainer.cfg.use_wandb:
if wandb.run is not None:
wandb.log(
{
"generated_samples": wandb.Table(
columns=[
"step",
"original",
"masked",
"generated",
"mask_ratio",
"masked_tokens",
"total_tokens",
],
data=[
[
step,
sample["original"],
sample["masked"],
sample["generated"],
f"{sample['mask_ratio']:.1%}",
sample["masked_tokens"],
sample["total_tokens"],
]
for sample in samples
],
)
},
step=step,
)

View File

@@ -1,409 +0,0 @@
"""Sample generation utilities for diffusion training."""
import re
from typing import Any, List, Literal, Optional
import torch
from axolotl.utils.logging import get_logger
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
temperature: float = 0.0,
mask_token_id: int = 32000,
mode: Literal["random", "completion"] = "random",
completion_tokens: int = 0,
target_mask_ratio: Optional[float] = None,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
temperature: Temperature for sampling (0.0 = deterministic)
mask_token_id: Token ID used for masking
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if dataloader is None:
LOG.warning("No validation dataloader provided, cannot generate samples")
return []
unwrapped_model = model.module if hasattr(model, "module") else model
training = unwrapped_model.training
unwrapped_model.eval()
# Resolve device robustly (some modules don't expose `.device`)
device = getattr(unwrapped_model, "device", None)
if device is None:
try:
device = next(unwrapped_model.parameters()).device
except StopIteration:
device = torch.device("cpu")
generations = []
# Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader(
dataloader, num_generation_samples, max_length, device
)
LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
# Generate samples using reverse diffusion process
with torch.no_grad():
for sample in sampled_sequences:
if isinstance(sample, dict):
original_sequence = sample.get("input_ids")
labels_seq = sample.get("labels")
attn_seq = sample.get("attention_mask")
else:
original_sequence = sample
labels_seq = None
attn_seq = None
generation_result = generate(
unwrapped_model,
tokenizer,
original_sequence,
num_diffusion_steps,
temperature,
mask_token_id,
mode=mode,
completion_tokens=completion_tokens,
target_mask_ratio=target_mask_ratio,
labels=labels_seq,
attention_mask=attn_seq,
)
generations.append(generation_result)
# Restore prior training state
if training:
unwrapped_model.train()
else:
unwrapped_model.eval()
return generations
def _sample_sequences_from_dataloader(
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[Any]:
"""Sample sequences from validation dataloader."""
sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []
sample_count = 0
# Skip a random number of batches (we could be more clever about this)
skip_batches = torch.randint(0, 10, (1,)).item()
batch_count = 0
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
labels = batch.get("labels")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
if seq_len < 10:
continue
# Determine truncation length
max_total = min(seq_len, max_length)
if labels is not None:
labels_i = labels[i][:seq_len]
answer_mask = labels_i != -100
if not answer_mask.any():
# No answer tokens; skip for SFT masking
continue
first_ans_idx = int(
torch.nonzero(answer_mask, as_tuple=False)[0].item()
)
prompt_len = first_ans_idx
if prompt_len >= max_total:
# Prompt alone reaches cap; cannot include any answer
continue
remaining_answer = int(answer_mask[prompt_len:].sum().item())
allowed_answer = max_total - prompt_len
take_answer = min(remaining_answer, allowed_answer)
if take_answer <= 0:
continue
actual_length = prompt_len + take_answer
else:
actual_length = max_total
# Extract the (possibly truncated) sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
attn_seq = (
attention_mask[i][:actual_length].unsqueeze(0).to(device)
if attention_mask is not None
else None
)
if labels is not None:
labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(
{
"input_ids": sequence,
"labels": labels_seq,
"attention_mask": attn_seq,
}
)
else:
if attn_seq is not None:
sampled_sequences.append(
{"input_ids": sequence, "attention_mask": attn_seq}
)
else:
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def generate(
model: torch.nn.Module,
tokenizer: Any,
original_sequence: torch.Tensor,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
*,
mode: Literal["random", "completion"] = "random",
completion_tokens: int = 0,
target_mask_ratio: Optional[float] = None,
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> dict:
"""Generate a single sample using reverse diffusion."""
# Get original text for comparison
original_text = tokenizer.decode(
original_sequence[0].cpu(), skip_special_tokens=True
)
# Build masked sequence
if (
labels is not None
and labels.numel() > 0
and (labels == -100).any()
and (labels != -100).any()
):
# SFT case: completely mask all answer tokens (labels != -100)
total_tokens = original_sequence.size(1)
masked_indices = (labels != -100).to(dtype=torch.bool)
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
masked_tokens = int(masked_indices.sum().item())
mask_ratio = masked_tokens / max(int(total_tokens), 1)
elif mode == "completion" and completion_tokens > 0:
# Append mask tokens to the right for completion
total_tokens = original_sequence.size(1) + int(completion_tokens)
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, -int(completion_tokens) :] = True
append = torch.full(
(1, int(completion_tokens)), mask_token_id, device=original_sequence.device
)
masked_sequence = torch.cat([original_sequence, append], dim=1)
masked_tokens = int(completion_tokens)
mask_ratio = masked_tokens / total_tokens
else:
# Apply random masking with optional fixed ratio
total_tokens = original_sequence.size(1)
if target_mask_ratio is None:
min_ratio, max_ratio = 0.1, 0.7
target_mask_ratio = (
torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
)
target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))
# Create random mask indices
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, mask_positions] = True
# Create masked sequence
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
# Calculate actual mask ratio
masked_tokens = masked_indices.sum().item()
mask_ratio = masked_tokens / total_tokens
# Get masked text for comparison
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
# Run reverse diffusion process
sequence = masked_sequence.clone()
attention_mask = create_bidirectional_attention_mask(
sequence, attention_mask, sample_packing=attention_mask is not None
)
for step in range(num_diffusion_steps):
sequence = _diffusion_step(
model,
sequence,
step,
num_diffusion_steps,
temperature,
mask_token_id,
attention_mask,
)
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
# Collect diagnostic info
final_ids = sequence[0].detach().cpu().tolist()
orig_ids_for_render = original_sequence[0].detach().cpu().tolist()
if masked_indices is not None:
masked_positions = (
torch.where(masked_indices[0])[0].detach().cpu().tolist()
if masked_indices.ndim == 2
else []
)
else:
masked_positions = []
result = {
"original": original_text,
"masked": masked_text,
"generated": generated_text,
"mask_ratio": mask_ratio,
"masked_tokens": masked_tokens,
"total_tokens": total_tokens,
"generated_ids": final_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids_for_render,
"formatted": (
f"Original: '{original_text}' → Masked: '{masked_text}' "
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
),
}
return result
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
# Remove literal special token strings
if hasattr(tokenizer, "special_tokens_map"):
for token_value in tokenizer.special_tokens_map.values():
if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
# Normalize whitespace but preserve newlines
cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
cleaned = re.sub(r"[ \t]+", " ", cleaned)
cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip()
return cleaned
def _diffusion_step(
model: torch.nn.Module,
sequence: torch.Tensor,
step: int,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Perform a single diffusion step with remasking."""
# Only process if there are masked tokens remaining
current_mask = sequence == mask_token_id
if not current_mask.any():
return sequence
# Create or use provided attention mask
if attention_mask is None:
batch_size, seq_len = sequence.shape
attention_mask = torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
)
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():
masked_logits = logits[current_mask]
# Apply temperature scaling
if temperature > 0:
scaled_logits = masked_logits / temperature
else:
scaled_logits = masked_logits
# Suppress mask token in outputs
scaled_logits[:, mask_token_id] = -float("inf")
if temperature > 0:
# Add Gumbel noise for sampling
gumbel_noise = -torch.log(
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
)
gumbel_logits = scaled_logits + gumbel_noise
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
else:
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
# Calculate probabilities for confidence scoring
probs = torch.softmax(scaled_logits, dim=-1)
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
# Determine how many tokens to unmask this step
remaining_masked = current_mask.sum().item()
if step == num_diffusion_steps - 1:
num_to_unmask = remaining_masked
else:
unmask_ratio = 1.0 / (num_diffusion_steps - step)
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
# Select highest confidence predictions to unmask
if num_to_unmask >= remaining_masked:
sequence[current_mask] = predicted_tokens
else:
_, top_indices = predicted_token_probs.topk(num_to_unmask)
mask_positions = torch.where(current_mask)[1]
positions_to_unmask = mask_positions[top_indices]
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
return sequence

Some files were not shown because too many files have changed in this diff Show More