From 9640338d37d0398cd3c0c0ab6e629b6dd9dcd5d3 Mon Sep 17 00:00:00 2001 From: salman Date: Tue, 9 Sep 2025 15:50:21 +0100 Subject: [PATCH 1/8] Default `include_tkps` to true (#3134) * default true * force e2e * causal trainer only * fix eval loggin [skip-ci] * revert setup.py * force tests * guarding * guarding * fix test case * use evaluate [skip-e2e] * use evaluate [skip-e2e] * kick off ci * fixing * reverting --- src/axolotl/core/builders/base.py | 7 ------- src/axolotl/core/builders/causal.py | 7 +++++++ src/axolotl/core/trainers/base.py | 4 ++-- src/axolotl/utils/callbacks/tokens_per_second.py | 16 +++++++++------- src/axolotl/utils/schemas/config.py | 4 ++-- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index bee291fa2..1ec818004 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -36,7 +36,6 @@ from axolotl.utils.callbacks import ( SaveModelOnFirstStepCallback, ) from axolotl.utils.callbacks.profiler import PytorchProfilerCallback -from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.distributed import build_parallelism_config from axolotl.utils.schemas.enums import CustomSupportedOptimizers @@ -145,12 +144,6 @@ class TrainerBuilderBase(abc.ABC): profiler_steps_start=self.cfg.profiler_steps_start, ) ) - if self.cfg.include_tkps: - callbacks.append( - TokensPerSecondCallback( - self.cfg.tensor_parallel_size, self.cfg.context_parallel_size - ) - ) return callbacks diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 057d0ab5c..ee6383d47 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -39,6 +39,7 @@ from axolotl.utils.collators import ( MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) +from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger @@ -71,6 +72,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.qat: callbacks.append(QATCallback(self.cfg.qat)) + if self.cfg.include_tkps: + callbacks.append( + TokensPerSecondCallback( + self.cfg.tensor_parallel_size, self.cfg.context_parallel_size + ) + ) return callbacks def get_post_trainer_create_callbacks(self, trainer): diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 06eef445b..d7555261f 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -342,10 +342,10 @@ class AxolotlTrainer( inputs_key = "labels" if "labels" in inputs else "input_ids" if hasattr(self.state, "num_tokens"): self.state.num_tokens = ( - self.state.num_tokens + (inputs[inputs_key] != -100).sum() + self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu() ) else: - self.state.num_tokens = (inputs[inputs_key] != -100).sum() + self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() if self.args.orpo_alpha: return self.orpo_compute_loss( diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index 85bcd5041..ead129240 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -43,11 +43,12 @@ class TokensPerSecondCallback(TrainerCallback): control: TrainerControl, **kwargs, ): # pylint: 
disable=unused-argument - step_time = time.perf_counter() - self.start_time - num_tokens_per_device = state.num_tokens.clone() - # non data parallel groups have duplicated tokens, so we avoid double-counting - num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size - state.last_tokens_per_second = num_tokens_per_device / step_time + if hasattr(state, "num_tokens"): + step_time = time.perf_counter() - self.start_time + num_tokens_per_device = state.num_tokens.clone() + # non data parallel groups have duplicated tokens, so we avoid double-counting + num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size + state.last_tokens_per_second = num_tokens_per_device / step_time def on_log( self, @@ -58,5 +59,6 @@ class TokensPerSecondCallback(TrainerCallback): **kwargs, ): # pylint: disable=unused-argument # after logging, clear the running metrics - state.last_tokens_per_second.zero_() - state.num_tokens = 0 + if hasattr(state, "last_tokens_per_second"): + state.last_tokens_per_second.zero_() + state.num_tokens = torch.zeros(1) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 32d7b68e7..e4c1fdf29 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -855,9 +855,9 @@ class AxolotlInputConfig( }, ) include_tkps: bool | None = Field( - default=None, + default=True, json_schema_extra={ - "description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens." + "description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens." }, ) neftune_noise_alpha: float | None = Field( From 79103b01ca1c914103d888d88fdb903e29840d4f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 10 Sep 2025 09:01:02 +0700 Subject: [PATCH 2/8] Feat: add seedoss (#3104) [skip ci] * feat: add seedoss cce * feat: add seedoss config and docs * fix: shouldn't have target modules with target linear * feat: add vram numbers * fix: hf link * fix: name * fix: support multipack seedoss * fix: merge error * feat: update seedoss instructions for transformers release --- examples/seed-oss/README.md | 54 ++++++++++++++++++ examples/seed-oss/seed-oss-36b-qlora.yaml | 56 +++++++++++++++++++ .../integrations/cut_cross_entropy/README.md | 3 + src/axolotl/monkeypatch/multipack.py | 1 + 4 files changed, 114 insertions(+) create mode 100644 examples/seed-oss/README.md create mode 100644 examples/seed-oss/seed-oss-36b-qlora.yaml diff --git a/examples/seed-oss/README.md b/examples/seed-oss/README.md new file mode 100644 index 000000000..5610c1316 --- /dev/null +++ b/examples/seed-oss/README.md @@ -0,0 +1,54 @@ +# Finetune ByteDance's Seed-OSS with Axolotl + +[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) are a series of 36B parameter open source models trained by ByteDance's Seed Team. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). 
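+
+   Before installing, you can optionally check whether the `transformers` build you already have ships the Seed-OSS architecture. A minimal sanity check (this assumes the model type is registered as `seed_oss`, matching the multipack patch below):
+
+```bash
+python -c "from transformers.models.auto.configuration_auto import CONFIG_MAPPING; print('seed_oss' in CONFIG_MAPPING)"
+```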
+ + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install Cut Cross Entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml +``` + +This config uses about 27.7 GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [ByteDance Seed Website](https://seed.bytedance.com/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/seed-oss/seed-oss-36b-qlora.yaml b/examples/seed-oss/seed-oss-36b-qlora.yaml new file mode 100644 index 000000000..00e7cf3eb --- /dev/null +++ b/examples/seed-oss/seed-oss-36b-qlora.yaml @@ -0,0 +1,56 @@ +base_model: ByteDance-Seed/Seed-OSS-36B-Instruct + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index a64bdd054..393412f64 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -34,6 +34,7 @@ plugins: - arcee - cohere - cohere2 +- deepseek_v3 - gemma - gemma2 - gemma3 @@ -42,6 +43,7 @@ plugins: - gemma3n_text - glm - glm4 +- glm4_moe - gpt_oss - granite - granitemoe @@ -64,6 +66,7 @@ plugins: - qwen3 - qwen3_moe - smollm3 +- seed_oss - voxtral ## Citation diff --git a/src/axolotl/monkeypatch/multipack.py 
b/src/axolotl/monkeypatch/multipack.py index e4f9ca2be..cbc546877 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -38,6 +38,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "smollm3", "gpt_oss", "arcee", + "seed_oss", ] From b71482cec5beb118efdd2bc466589e9a6eb64e77 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 10 Sep 2025 09:03:30 +0700 Subject: [PATCH 3/8] Feat: add hunyuan v1 (#3016) * feat: add hunyuan cce support * feat: update cce docs * feat: add multipack support for granite and hunyuan * feat: add hunyuan docs and example config * feat: update readme instructions to include CCE installation * fix: chat template log appearing despite tokenizer already having template * feat: add vram usage * fix: remove duplicate cce install * fix: use latest commit of PR in case rebased/pushed * Revert "fix: use latest commit of PR in case rebased/pushed" This reverts commit 8b60aa00de5511c09a6cad64ae1cf476e6a5eddc. * feat: update doc as upstream merged --- examples/devstral/README.md | 8 +- examples/hunyuan/README.md | 85 ++++++++++++++++++++ examples/hunyuan/hunyuan-v1-dense-qlora.yaml | 64 +++++++++++++++ examples/magistral/README.md | 8 +- examples/voxtral/README.md | 3 + src/axolotl/loaders/tokenizer.py | 2 +- src/axolotl/monkeypatch/multipack.py | 4 + 7 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 examples/hunyuan/README.md create mode 100644 examples/hunyuan/hunyuan-v1-dense-qlora.yaml diff --git a/examples/devstral/README.md b/examples/devstral/README.md index b53635a8f..ae0860662 100644 --- a/examples/devstral/README.md +++ b/examples/devstral/README.md @@ -20,7 +20,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' ``` -2. Run the finetuning example: +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage + +```bash +python scripts/cutcrossentropy_install.py | sh +``` + +3. Run the finetuning example: ```bash axolotl train examples/devstral/devstral-small-qlora.yml diff --git a/examples/hunyuan/README.md b/examples/hunyuan/README.md new file mode 100644 index 000000000..96c6bbcfa --- /dev/null +++ b/examples/hunyuan/README.md @@ -0,0 +1,85 @@ +# Finetune HunYuan with Axolotl + +Tencent released a family of opensource models called HunYuan with varying parameter scales of 0.5B, 1.8B, 4B, and 7B scale for both Pre-trained and Instruct variants. The models can be found at [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as HunYuan is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. 
Run the finetuning example:
+
+```bash
+axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml
+```
+
+This config uses about 4.7 GB VRAM.
+
+Let us know how it goes. Happy finetuning! 🚀
+
+### Dataset
+
+HunYuan Instruct models can choose to enter a slow think or fast think pattern. For best performance when fine-tuning their Instruct models, your dataset should be adjusted to match the corresponding pattern.
+
+```python
+# fast think pattern
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "/no_think What color is the sun?" },
+    {"role": "assistant", "content": "<think>\n\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
+]
+
+# slow think pattern
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "/think What color is the sun?" },
+    {"role": "assistant", "content": "<think>\nThe user is asking about the color of the sun. I need to ...\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
+]
+```
+
+### TIPS
+
+- For inference, the official Tencent team recommends the following generation settings:
+
+```json
+{
+  "do_sample": true,
+  "top_k": 20,
+  "top_p": 0.8,
+  "repetition_penalty": 1.05,
+  "temperature": 0.7
+}
+```
+
+- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
+- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+
+## Optimization Guides
+
+- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+
+## Related Resources
+
+- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl Website](https://axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
diff --git a/examples/hunyuan/hunyuan-v1-dense-qlora.yaml b/examples/hunyuan/hunyuan-v1-dense-qlora.yaml
new file mode 100644
index 000000000..a94345a61
--- /dev/null
+++ b/examples/hunyuan/hunyuan-v1-dense-qlora.yaml
@@ -0,0 +1,64 @@
+base_model: tencent/Hunyuan-0.5B-Instruct
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_8bit: false
+load_in_4bit: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/lora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/README.md 
b/examples/magistral/README.md index 48ce712da..f4f278208 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -18,7 +18,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' ``` -2. Run the finetuning example: +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage + +```bash +python scripts/cutcrossentropy_install.py | sh +``` + +3. Run the finetuning example: ```bash axolotl train examples/magistral/magistral-small-qlora.yaml diff --git a/examples/voxtral/README.md b/examples/voxtral/README.md index f31e9cfd0..984af4ddb 100644 --- a/examples/voxtral/README.md +++ b/examples/voxtral/README.md @@ -22,6 +22,9 @@ pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' # audio pip3 install librosa==0.11.0 pip3 install 'mistral_common[audio]==1.8.3' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh ``` 3. Run the finetuning example: diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index dcc255938..37b66ac83 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -296,7 +296,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: ) tokenizer.chat_template = chat_template_string - else: + elif getattr(tokenizer, "chat_template", None) is None: LOG.info( "No Chat template selected. Consider adding a chat template for easier inference." ) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index cbc546877..a32430d9f 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -36,6 +36,10 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "glm", "glm4", "smollm3", + "granite", + "granitemoe", + "hunyuan_v1_dense", + "hunyuan_v1_moe", "gpt_oss", "arcee", "seed_oss", From 1b53c49e1a8408ff209ae72a480681d18f7f8c81 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 10 Sep 2025 20:27:00 -0400 Subject: [PATCH 4/8] text diffusion training plugin (#3067) * diffusion training plugin * cleanup * nits * fixes + improvements * add back in reinit_weights (clobbered?); masking / pretrain fixes * nits * cleanup; tests draft * sample generation, tests fixes * fixes * nits * add inference support; add auto-mask token support * nits * nits * progress * simplify logging * lint * prefix args with diffusion_ * coderabbito * tests fix * nit * nits * cleanup + nits * nits * fix SFT sample gen * fixes * fix * comments * comments * lint * reward model lora fix * cleanup; fix pretraining_dataset case * gradio inference * update cfgs * update cfgs * train, generation parity, cleanup * fix * simplify * test * test fix --- .pre-commit-config.yaml | 2 +- .../colab-axolotl-example.ipynb | 2 +- examples/llama-3/diffusion/pretrain-1b.yaml | 56 +++ examples/llama-3/diffusion/sft-1b.yaml | 59 +++ src/axolotl/cli/inference.py | 63 ++- src/axolotl/cli/utils/diffusion.py | 375 ++++++++++++++++ src/axolotl/core/builders/causal.py | 15 +- src/axolotl/core/trainers/base.py | 46 +- src/axolotl/integrations/base.py | 2 +- src/axolotl/integrations/config.py | 2 +- src/axolotl/integrations/diffusion/README.md | 154 +++++++ .../integrations/diffusion/__init__.py | 19 + src/axolotl/integrations/diffusion/args.py | 95 ++++ .../integrations/diffusion/callbacks.py | 174 ++++++++ .../integrations/diffusion/generation.py | 409 
++++++++++++++++++ src/axolotl/integrations/diffusion/plugin.py | 41 ++ src/axolotl/integrations/diffusion/trainer.py | 301 +++++++++++++ src/axolotl/integrations/diffusion/utils.py | 159 +++++++ src/axolotl/loaders/adapter.py | 12 +- src/axolotl/loaders/model.py | 118 +++-- src/axolotl/loaders/patch_manager.py | 5 +- src/axolotl/monkeypatch/accelerate/fsdp2.py | 8 +- .../monkeypatch/attention/flex_attn.py | 3 +- src/axolotl/monkeypatch/deepspeed_utils.py | 1 + src/axolotl/utils/config/__init__.py | 2 +- src/axolotl/utils/data/__init__.py | 8 +- src/axolotl/utils/data/sft.py | 2 +- src/axolotl/utils/environment.py | 2 - src/axolotl/utils/schemas/config.py | 6 + src/axolotl/utils/schemas/validation.py | 1 - tests/e2e/test_diffusion.py | 139 ++++++ tests/integrations/test_diffusion.py | 274 ++++++++++++ tests/integrations/test_diffusion_callback.py | 92 ++++ tests/test_streaming.py | 4 +- 34 files changed, 2550 insertions(+), 101 deletions(-) create mode 100644 examples/llama-3/diffusion/pretrain-1b.yaml create mode 100644 examples/llama-3/diffusion/sft-1b.yaml create mode 100644 src/axolotl/cli/utils/diffusion.py create mode 100644 src/axolotl/integrations/diffusion/README.md create mode 100644 src/axolotl/integrations/diffusion/__init__.py create mode 100644 src/axolotl/integrations/diffusion/args.py create mode 100644 src/axolotl/integrations/diffusion/callbacks.py create mode 100644 src/axolotl/integrations/diffusion/generation.py create mode 100644 src/axolotl/integrations/diffusion/plugin.py create mode 100644 src/axolotl/integrations/diffusion/trainer.py create mode 100644 src/axolotl/integrations/diffusion/utils.py create mode 100644 tests/e2e/test_diffusion.py create mode 100644 tests/integrations/test_diffusion.py create mode 100644 tests/integrations/test_diffusion_callback.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92ddc7f41..9c80898ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: rev: v0.12.12 hooks: - id: ruff - args: [--fix] + args: [--fix, --select, I] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.17.1 diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index b780a1c48..0e6ba984e 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -176,8 +176,8 @@ } ], "source": [ - "from axolotl.utils.dict import DictDefault\n", "from axolotl.cli.config import load_cfg\n", + "from axolotl.utils.dict import DictDefault\n", "\n", "# Axolotl provides full control and transparency over model and training configuration\n", "config = DictDefault(\n", diff --git a/examples/llama-3/diffusion/pretrain-1b.yaml b/examples/llama-3/diffusion/pretrain-1b.yaml new file mode 100644 index 000000000..8d05e4c60 --- /dev/null +++ b/examples/llama-3/diffusion/pretrain-1b.yaml @@ -0,0 +1,56 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +pretraining_dataset: + - path: wikitext + name: wikitext-103-raw-v1 + type: completion + field: text + +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin + +diffusion: + noise_schedule: cosine + min_mask_ratio: 0.15 + max_mask_ratio: 0.85 + num_diffusion_steps: 128 + eps: 5e-4 + importance_weighting: true + mask_token_id: 128002 + generate_samples: true + generation_interval: 250 + +output_dir: ./outputs/model-out + 
+sequence_len: 512 +sample_packing: true + +gradient_accumulation_steps: 8 +micro_batch_size: 4 +max_steps: 10000 +warmup_ratio: 0.1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 3e-4 +sdp_attention: true + +bf16: auto +tf32: true + +logging_steps: 1 +save_strategy: steps +save_steps: 1000 + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/diffusion/sft-1b.yaml b/examples/llama-3/diffusion/sft-1b.yaml new file mode 100644 index 000000000..f3b29a809 --- /dev/null +++ b/examples/llama-3/diffusion/sft-1b.yaml @@ -0,0 +1,59 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +val_set_size: 0.05 + +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin + +diffusion: + noise_schedule: cosine + min_mask_ratio: 0.1 + max_mask_ratio: 0.9 + num_diffusion_steps: 128 + eps: 1e-3 + importance_weighting: true + mask_token_id: 128002 + generate_samples: true + generation_interval: 250 + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true +eval_sample_packing: true + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 1 +warmup_steps: 0.1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 1e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +resume_from_checkpoint: +sdp_attention: true + +logging_steps: 1 +save_strategy: best +eval_strategy: epoch + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index debe57167..30d407713 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -14,6 +14,13 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from axolotl.cli.args import InferenceCliArgs from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.cli.utils.diffusion import ( + diffusion_inference, + launch_diffusion_gradio_ui, + render_html, + run_diffusion, +) +from axolotl.integrations.base import PluginManager from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -29,6 +36,7 @@ def get_multi_line_input() -> str: Possibly multi-line, possibly empty stdin input as a string. """ print("Give me an instruction (Ctrl + D to submit): ") + print("=" * 80) instruction = "" for line in sys.stdin: @@ -43,9 +51,9 @@ def do_inference( cli_args: InferenceCliArgs, ): """ - Runs inference on the command line in a loop. User input is accepted, a chat template - is (optionally) applied, and the model specified in the `axolotl` config is used to - generate completions according to a default generation config. + Runs inference on the command line in a loop. User input is accepted, a chat + template is (optionally) applied, and the model specified in the `axolotl` config is + used to generate completions according to a default generation config. Args: cfg: Dictionary mapping `axolotl` config keys to values. 
@@ -64,16 +72,28 @@ def do_inference( chat_template_str = get_chat_template_from_config( cfg, ds_cfg=None, tokenizer=tokenizer ) - elif cfg.datasets[0].type == "chat_template": + elif cfg.datasets and cfg.datasets[0].type == "chat_template": chat_template_str = get_chat_template_from_config( cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer ) model = model.to(cfg.device, dtype=cfg.torch_dtype) + # Detect diffusion mode + plugin_manager = PluginManager.get_instance() + is_diffusion = any( + plugin.__class__.__name__ == "DiffusionPlugin" + for plugin in plugin_manager.plugins.values() + ) + + if is_diffusion: + print("=" * 80) + print("Commands:") + print(":complete N -> completion mode with N tokens (default 64)") + print(":mask R -> random masking with ratio R (0.0–1.0)") + while True: print("=" * 80) - # support for multiline inputs instruction = get_multi_line_input() if not instruction: return @@ -103,9 +123,19 @@ def do_inference( else: batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - print("=" * 40) + print("=" * 80) model.eval() with torch.no_grad(): + if is_diffusion: + diffusion_inference( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + ) + continue + generation_config = GenerationConfig( repetition_penalty=1.1, max_new_tokens=1024, @@ -128,7 +158,7 @@ def do_inference( generation_config=generation_config, streamer=streamer, ) - print("=" * 40) + print("=" * 80) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) @@ -161,13 +191,30 @@ def do_inference_gradio( chat_template_str = get_chat_template_from_config( cfg, ds_cfg=None, tokenizer=tokenizer ) - elif cfg.datasets[0].type == "chat_template": + elif cfg.datasets and cfg.datasets[0].type == "chat_template": chat_template_str = get_chat_template_from_config( cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer ) model = model.to(cfg.device, dtype=cfg.torch_dtype) + # Detect diffusion mode + plugin_manager = PluginManager.get_instance() + is_diffusion = any( + plugin.__class__.__name__ == "DiffusionPlugin" + for plugin in plugin_manager.plugins.values() + ) + + if is_diffusion: + launch_diffusion_gradio_ui( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompter_module=prompter_module, + chat_template_str=chat_template_str, + ) + return + def generate(instruction): if not instruction: return diff --git a/src/axolotl/cli/utils/diffusion.py b/src/axolotl/cli/utils/diffusion.py new file mode 100644 index 000000000..f83d9077b --- /dev/null +++ b/src/axolotl/cli/utils/diffusion.py @@ -0,0 +1,375 @@ +"""Helpers for diffusion-mode inference in CLI and Gradio.""" + +from __future__ import annotations + +import gradio as gr +import torch +from colorama import Fore, Style + +from axolotl.integrations.diffusion import generate, resolve_mask_token_id +from axolotl.utils.dict import DictDefault + + +def diffusion_inference( + model, + tokenizer, + cfg, + prompt: str, + chat_template_str: str | None = None, +): + """Diffusion inference helper method.""" + mode = "random" + completion_tokens = 0 + target_mask_ratio = None + mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt) + + if cleaned: + prompt = cleaned + + info = run_diffusion( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + mode=mode, + target_mask_ratio=target_mask_ratio, + completion_tokens=completion_tokens, + ) + masked_text = info["masked_text"] + mask_ratio = info["mask_ratio"] + generated_ids = 
info["generated_ids"] + masked_positions = info["masked_positions"] + orig_ids = info["orig_ids"] + + # Display with masked preview and colored diff + if masked_text is not None and mask_ratio is not None: + print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n") + if generated_ids is not None: + # Compute per-token style + styles: list[str] = [] + for i, tid in enumerate(generated_ids): + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + styles.append("green") # correct fill + elif i < len(orig_ids): + styles.append("red") # incorrect fill + else: + styles.append("normal") # appended + else: + same = i < len(orig_ids) and tid == orig_ids[i] + styles.append("dim" if same else "normal") + + # Group contiguous spans by style + styled_spans: list[tuple[str, int, int]] = [] + if generated_ids: + current_style = styles[0] + start = 0 + for i in range(1, len(generated_ids)): + s = styles[i] + if s != current_style: + styled_spans.append((current_style, start, i)) + current_style, start = s, i + styled_spans.append((current_style, start, len(generated_ids))) + + out_parts = [] + for style_name, a, b in styled_spans: + chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False) + if style_name == "green": + out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL) + elif style_name == "red": + out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL) + else: + if style_name == "dim": + out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL) + else: + out_parts.append(chunk_text) + print("Generated:\n" + "".join(out_parts)) + else: + print("Generated:\n(no output)") + + +def _parse_commands(text: str): + """ + Parse leading diffusion commands. + + Supported at start of input (can be chained): + :complete N -> completion mode with N tokens (default 64) + :mask R -> random masking with ratio R in [0, 1] + """ + tokens = text.strip().split() + i = 0 + mode = "random" + completion_tokens = 0 + target_mask_ratio = None + consumed = 0 + while i < len(tokens) and tokens[i].startswith(":"): + cmd = tokens[i] + i += 1 + consumed = i + if cmd == ":complete": + mode = "completion" + if i < len(tokens): + try: + completion_tokens = int(tokens[i]) + i += 1 + consumed = i + except Exception: + completion_tokens = 64 + else: + completion_tokens = 64 + elif cmd == ":mask": + mode = "random" + if i < len(tokens): + try: + target_mask_ratio = float(tokens[i]) + i += 1 + consumed = i + except Exception: + target_mask_ratio = None + else: + i -= 1 + consumed = i + break + + cleaned = " ".join(tokens[consumed:]) + + return mode, completion_tokens, target_mask_ratio, cleaned + + +def run_diffusion( + *, + model, + tokenizer, + cfg: DictDefault, + prompt: str, + chat_template_str: str | None, + mode: str = "random", + target_mask_ratio: float | None = None, + completion_tokens: int = 0, +): + """Run a single diffusion generation and return a structured result dict.""" + if chat_template_str: + batch = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False) + + seq = batch["input_ids"].to(cfg.device) + gen_mode = "completion" if mode == "completion" else "random" + comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0 + + result = 
generate( + model, + tokenizer, + original_sequence=seq[:1], + num_diffusion_steps=cfg.diffusion.num_diffusion_steps, + temperature=cfg.diffusion.generation_temperature, + mask_token_id=int(mask_token_id), + mode=gen_mode, # type: ignore[arg-type] + completion_tokens=comp_tokens, + target_mask_ratio=target_mask_ratio, + ) + + masked_text = result.get("masked") if isinstance(result, dict) else None + mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None + generated_ids = result.get("generated_ids") if isinstance(result, dict) else None + masked_positions = ( + set(result.get("masked_positions") or []) if isinstance(result, dict) else set() + ) + orig_ids = seq[0].detach().cpu().tolist() + + return { + "masked_text": masked_text, + "mask_ratio": mask_ratio, + "generated_ids": generated_ids, + "masked_positions": masked_positions, + "orig_ids": orig_ids, + } + + +def render_html( + *, + generated_ids: list[int] | None, + orig_ids: list[int], + masked_positions: set[int], + tokenizer, +) -> str: + """Render HTML visualizing diffusion outputs.""" + if not generated_ids: + return "
<pre>Generated:\n(no output)</pre>
" + + def _style_for(i: int, tid: int) -> str: + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + return "green" + if i < len(orig_ids): + return "red" + return "normal" + same = i < len(orig_ids) and tid == orig_ids[i] + return "dim" if same else "normal" + + # Group contiguous spans by style to reduce HTML size + spans: list[tuple[str, int, int]] = [] + if generated_ids: + cur = _style_for(0, generated_ids[0]) + start = 0 + for i in range(1, len(generated_ids)): + s = _style_for(i, generated_ids[i]) + if s != cur: + spans.append((cur, start, i)) + cur, start = s, i + spans.append((cur, start, len(generated_ids))) + + html_parts = [] + for style_name, a, b in spans: + txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False) + if style_name == "green": + html_parts.append(f'{txt}') + elif style_name == "red": + html_parts.append(f'{txt}') + elif style_name == "dim": + html_parts.append(f'{txt}') + else: + html_parts.append(txt) + + legend = ( + '
' + 'correct, ' + 'incorrect, ' + 'unchanged' + "
" + ) + + return ( + legend + + '
Generated:\n'
+        + "".join(html_parts)
+        + "</pre>
" + ) + + +def launch_diffusion_gradio_ui( + *, + model, + tokenizer, + cfg: DictDefault, + prompter_module=None, + chat_template_str: str | None = None, +): + """Build and launch a simple Gradio UI for diffusion inference.""" + with gr.Blocks( + title=cfg.get("gradio_title", "Axolotl Diffusion Interface") + ) as demo: + gr.Markdown( + """ + ## Axolotl Diffusion Inference + - Mode "Random" masks tokens at a target ratio and fills them. + - Mode "Completion" appends N masked tokens at the end and fills them. + """ + ) + + with gr.Row(): + mode = gr.Radio( + choices=["random", "completion"], + value="random", + label="Mode", + ) + mask_ratio = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.05, + value=0.4, + label="Mask ratio (random mode)", + interactive=True, + ) + completion_tokens = gr.Number( + value=64, + precision=0, + label="Completion tokens (completion mode)", + interactive=True, + visible=False, + ) + + instruction = gr.Textbox(label="Instruction", lines=6) + run_btn = gr.Button("Generate") + + masked_preview = gr.Textbox(label="Masked preview", lines=6) + html_out = gr.HTML(label="Generated") + + def _toggle_controls(selected_mode: str): + return ( + gr.update(visible=(selected_mode == "random")), + gr.update(visible=(selected_mode == "completion")), + ) + + mode.change( + _toggle_controls, + inputs=[mode], + outputs=[mask_ratio, completion_tokens], + ) + + def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int): + if not instruction_text: + return "", "
<pre>Generated:\n(no output)</pre>
" + + if prompter_module: + prompt: str = next( + prompter_module().build_prompt( + instruction=instruction_text.strip("\n") + ) + ) + else: + prompt = instruction_text.strip() + + info = run_diffusion( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + mode=selected_mode, + target_mask_ratio=mratio if selected_mode == "random" else None, + completion_tokens=int(ctoks) if selected_mode == "completion" else 0, + ) + + masked_text = info.get("masked_text") + mask_ratio_val = info.get("mask_ratio") + generated_ids = info.get("generated_ids") + masked_positions = info.get("masked_positions") or set() + orig_ids = info.get("orig_ids") or [] + + preview = ( + f"Masked ({mask_ratio_val:.1%}):\n{masked_text}" + if masked_text is not None and mask_ratio_val is not None + else "" + ) + html = render_html( + generated_ids=generated_ids, + orig_ids=orig_ids, + masked_positions=masked_positions, + tokenizer=tokenizer, + ) + return preview, html + + run_btn.click( + _gen, + inputs=[instruction, mode, mask_ratio, completion_tokens], + outputs=[masked_preview, html_out], + ) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index ee6383d47..f7f350e1a 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -7,7 +7,11 @@ from pathlib import Path from typing import Type, Union import transformers -from transformers import DataCollatorWithFlattening, EarlyStoppingCallback +from transformers import ( + DataCollatorWithFlattening, + EarlyStoppingCallback, + Trainer, +) from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.builders.base import TrainerBuilderBase @@ -23,15 +27,16 @@ from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( + LossWatchDogCallback, + SaveBetterTransformerModelCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, colab_inference_post_train_callback, log_prediction_callback_factory, - LossWatchDogCallback, - SaveBetterTransformerModelCallback, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.qat import QATCallback +from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -39,7 +44,6 @@ from axolotl.utils.collators import ( MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) -from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger @@ -391,10 +395,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): **data_collator_kwargs, ) sig = inspect.signature(trainer_cls) - if "processing_class" in sig.parameters: + if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer): trainer_kwargs["processing_class"] = self.tokenizer elif "tokenizer" in sig.parameters: trainer_kwargs["tokenizer"] = self.tokenizer + if ( trainer_cls not in 
[AxolotlRewardTrainer, AxolotlPRMTrainer] and self.cfg.datasets is not None diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d7555261f..3427a0b86 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -49,6 +49,13 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__) +REDUCTION_FNS = { + "mean": torch.mean, + "min": torch.min, + "max": torch.max, + "sum": torch.sum, +} + class AxolotlTrainer( PackingMixin, @@ -89,7 +96,9 @@ class AxolotlTrainer( super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) + self._stored_metrics = defaultdict( + lambda: defaultdict(lambda: {"values": [], "reduction": "mean"}) + ) if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -585,9 +594,17 @@ class AxolotlTrainer( """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() + + for key, metric_data in self._stored_metrics[train_eval].items(): + values = torch.tensor(metric_data["values"]) # type: ignore[arg-type] + reduction_type = metric_data["reduction"] + + fn = REDUCTION_FNS.get(reduction_type) + if fn is None: + raise NotImplementedError( + "Metric reduction must be one of [mean, min, max, sum]" + ) + logs[key] = round(fn(values).item(), 4) if is_main_process(): # Add memory usage @@ -611,10 +628,27 @@ class AxolotlTrainer( return super().log(logs, start_time) def store_metrics( - self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train" + self, + metrics: dict[str, float] | dict[str, tuple[int | float, str]], + train_eval: Literal["train", "eval"] = "train", + reduction: Literal["mean", "min", "max", "sum"] = "mean", ) -> None: + """ + Store metrics with specified reduction type. + + Args: + metrics: Dictionary of metric names to values, or metric names to (value, + reduction_type) tuples. + train_eval: Whether this is for training or evaluation. + """ for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) + if isinstance(value, tuple): + value, _reduction = value # type: ignore[assignment] + else: + value, _reduction = value, reduction + + self._stored_metrics[train_eval][key]["values"].append(value) + self._stored_metrics[train_eval][key]["reduction"] = _reduction def _save_checkpoint(self, model, trial, **kwargs): # make sure the checkpoint dir exists, since trainer is flakey diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 8edee18a3..c66bc01c6 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -142,7 +142,7 @@ class BasePlugin: model: The loaded model. """ - def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: """Returns a custom class for the trainer. 
Args: diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index 2217b2819..8ae8aab39 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, + AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase def merge_input_args(): diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md new file mode 100644 index 000000000..c27f33de1 --- /dev/null +++ b/src/axolotl/integrations/diffusion/README.md @@ -0,0 +1,154 @@ +# Diffusion LM Training Plugin for Axolotl + +This plugin enables diffusion language model training using an approach inspired by +LLaDA (Large Language Diffusion Models) within Axolotl. + +## Overview + +LLaDA is a diffusion-based approach to language model training that uses: +- **Random token masking** during training instead of next-token prediction +- **Bidirectional attention** to allow the model to attend to the full context +- **Importance weighting** based on masking probabilities for stable training + +This approach can lead to more robust language models with better understanding of +bidirectional context. + +## Installation + +The plugin is included with Axolotl. See our +[installation docs](https://docs.axolotl.ai/docs/installation.html). + +## Quickstart + +Train with an example config (Llama‑3.2 1B): + - Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml` + - SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml` + +### Basic Configuration + +You can also modify your existing configs to enable / customize diffusion training. + +Add the following to your Axolotl config: + +```yaml +# Enable diffusion LM training plugin +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin +``` + +And, configure the nested `diffusion` block (defaults shown): + +```yaml +diffusion: + noise_schedule: linear # or "cosine" + min_mask_ratio: 0.1 + max_mask_ratio: 0.9 + num_diffusion_steps: 128 + eps: 1e-3 + importance_weighting: true + + # Mask token (training auto-adds if missing, avoid pad/eos) + mask_token_str: "<|diffusion_mask|>" + # Or use an existing special token id (e.g., 128002 for Llama-3.x) + # mask_token_id: 128002 + + # Sample generation during training (optional) + generate_samples: true + generation_interval: 100 + num_generation_samples: 3 + generation_steps: 128 + generation_temperature: 0.0 + generation_max_length: 100 +``` + +## Supported Models + +Any models that support 4D attention masks should work out of the box. If not, please +create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a +[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)! + +## How It Works + +### Random Masking +During training, tokens are randomly masked: +- Sample timestep `t` uniformly from [0, 1] +- Calculate masking probability: `p = (1 - eps) * t + eps` +- Randomly mask tokens with probability `p` + +### Diffusion Loss + +Loss is computed only on masked tokens with (optional) importance weighting: + +```python +loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens +``` + +## Sample Generation + +When `diffusion.generate_samples: true`, the plugin generates samples during training: + +``` +Sample 1: + Original (45 tokens): The quick brown fox jumps over the lazy dog... 
+ Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]... + Generated: The quick brown fox jumps over the lazy dog... +``` + +Samples are logged to console and wandb (if enabled). + +## Inference + +Diffusion inference is integrated into the standard Axolotl CLI. Use the same config +you trained with and run: + +``` +axolotl inference path/to/your-config.yaml +``` + +Optionally, pass `--gradio` to use a simple web interface. + +Interactive controls (prefix the prompt with commands): +- `:complete N` → completion mode with N new masked tokens appended (default 64) +- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0] + +Example session: + +``` +================================================================================ +Commands: +:complete N -> completion mode with N tokens (default 64) +:mask R -> random masking with ratio R (0.0–1.0) +================================================================================ +Give me an instruction (Ctrl + D to submit): + +:mask 0.4 The quick brown fox jumps over the lazy dog + +Masked (40.0%): +The [MASK] brown [MASK] jumps over the [MASK] dog + +Generated: +The quick brown fox jumps over the loud dog +``` + +## Metrics and Monitoring + +The plugin adds (or modifies) several metrics to track diffusion training: + +- `train/loss`: Weighted diffusion loss +- `train/accuracy`: Accuracy on masked tokens +- `train/mask_ratio`: Average fraction of tokens masked +- `train/num_masked_tokens`: Number of tokens masked +- `train/avg_p_mask`: Average masking probability +- `train/ce_loss`: Unweighted cross-entropy loss +- `train/importance_weight_avg`: Average importance weight + +## Limitations + +- No flash attention support +- No RL training support + +## References + +- [LLaDA Paper](https://arxiv.org/abs/2404.10406) +- [Axolotl Documentation](https://docs.axolotl.ai/) +- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args) diff --git a/src/axolotl/integrations/diffusion/__init__.py b/src/axolotl/integrations/diffusion/__init__.py new file mode 100644 index 000000000..9e38cc5c1 --- /dev/null +++ b/src/axolotl/integrations/diffusion/__init__.py @@ -0,0 +1,19 @@ +"""Diffusion LM training plugin init.""" + +from .args import DiffusionArgs, DiffusionConfig +from .callbacks import DiffusionGenerationCallback +from .generation import generate +from .plugin import DiffusionPlugin +from .trainer import DiffusionTrainer +from .utils import create_bidirectional_attention_mask, resolve_mask_token_id + +__all__ = [ + "DiffusionArgs", + "DiffusionPlugin", + "DiffusionTrainer", + "generate", + "resolve_mask_token_id", + "create_bidirectional_attention_mask", + "DiffusionGenerationCallback", + "DiffusionConfig", +] diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py new file mode 100644 index 000000000..4f5bfe499 --- /dev/null +++ b/src/axolotl/integrations/diffusion/args.py @@ -0,0 +1,95 @@ +"""Config args for diffusion LM training (nested under `diffusion:`).""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field, model_validator + + +class DiffusionConfig(BaseModel): + """Nested diffusion configuration available under the `diffusion` key.""" + + # Noise schedule config + noise_schedule: Literal["linear", "cosine"] = Field( + default="linear", description="Type of noise schedule for diffusion training" + ) + min_mask_ratio: float = 
Field( + default=0.1, + ge=0.0, + le=1.0, + description="Minimum masking ratio for diffusion noise schedule", + ) + max_mask_ratio: float = Field( + default=0.9, + ge=0.0, + le=1.0, + description="Maximum masking ratio for diffusion noise schedule", + ) + num_diffusion_steps: int = Field( + default=128, ge=1, description="Number of diffusion timesteps" + ) + eps: float = Field( + default=1e-3, + ge=0.0, + le=1.0, + description="Epsilon value for minimum masking probability in forward process", + ) + + # Training config + importance_weighting: bool = Field( + default=True, + description="Apply importance weighting to loss based on masking probability", + ) + mask_token_id: int | None = Field( + default=None, + description=( + "Token ID to use for masking. Unset by default; can use one of the " + "tokenizer's special tokens here." + ), + ) + mask_token_str: str | None = Field( + default=None, + description=( + "Token string to use as a mask. If `mask_token_id` is invalid or unset, " + "this token will be ensured to exist as an additional special token and " + "used. If absent, a default '<|diffusion_mask|>' will be added." + ), + ) + + # Sample generation config + generate_samples: bool = Field( + default=True, description="Enable sample generation during training" + ) + generation_interval: int = Field( + default=100, ge=1, description="Generate samples every N steps" + ) + num_generation_samples: int = Field( + default=3, ge=1, description="Number of samples to generate each time" + ) + generation_steps: int = Field( + default=128, ge=1, description="Number of diffusion steps for generation" + ) + generation_temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for generation sampling (0.0 = deterministic)", + ) + generation_max_length: int = Field( + default=100, ge=1, description="Maximum sequence length for generation" + ) + + @model_validator(mode="after") + def _validate_mask_ratios(self) -> "DiffusionConfig": + if self.min_mask_ratio > self.max_mask_ratio: + raise ValueError("min_mask_ratio must be ≤ max_mask_ratio") + return self + + +class DiffusionArgs(BaseModel): + """Plugin entry that exposes the nested `diffusion` block to the core config.""" + + diffusion: DiffusionConfig = Field( + default_factory=DiffusionConfig, + description="Diffusion training configuration. 
Only nested block is supported.", + ) diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py new file mode 100644 index 000000000..18a64023b --- /dev/null +++ b/src/axolotl/integrations/diffusion/callbacks.py @@ -0,0 +1,174 @@ +"""Callbacks for diffusion training.""" + +import logging +import sys + +import wandb +from colorama import Fore, Style +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from .generation import generate_samples + +# Simpler logger for more readable sample generation +logger = logging.getLogger(__name__) +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(handler) + logger.propagate = False +logger.setLevel(logging.INFO) + + +class DiffusionGenerationCallback(TrainerCallback): + """Callback for generating samples during diffusion training.""" + + def __init__(self, trainer): + self.trainer = trainer + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Generate samples at specified intervals.""" + if ( + state.global_step > 0 + and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0 + ): + if not self.trainer.state.is_world_process_zero: + return + + # Use eval dataloader if available, otherwise use train dataloader + dataloader = None + try: + if getattr(self.trainer, "eval_dataset", None) is not None: + dataloader = self.trainer.get_eval_dataloader() + except Exception: + dataloader = None + if dataloader is None: + dataloader = self.trainer.get_train_dataloader() + + # Generate samples + diffusion_cfg = self.trainer.cfg.diffusion + samples = generate_samples( + model=self.trainer.model, + tokenizer=self.trainer.processing_class, + dataloader=dataloader, + num_generation_samples=diffusion_cfg.num_generation_samples, + max_length=diffusion_cfg.generation_max_length, + num_diffusion_steps=diffusion_cfg.generation_steps, + temperature=diffusion_cfg.generation_temperature, + mask_token_id=diffusion_cfg.mask_token_id, + ) + + # Log samples + self._log_samples(samples, state.global_step) + + def _log_samples(self, samples: list, step: int): + """Log generated samples.""" + if not samples: + return + + logger.info("=" * 60) + logger.info("GENERATED SAMPLES") + logger.info("=" * 60) + + for i, sample_data in enumerate(samples, 1): + original = sample_data["original"] + masked = sample_data["masked"] + generated = sample_data["generated"] + mask_ratio = sample_data["mask_ratio"] + masked_tokens = sample_data["masked_tokens"] + total_tokens = sample_data["total_tokens"] + + logger.info(f"\nSample {i}:") + logger.info(f"\tOriginal ({total_tokens} tokens): {original}") + logger.info( + f"\tMasked ({masked_tokens}/{total_tokens} tokens, " + f"{mask_ratio:.1%}): {masked}" + ) + + try: + gen_ids = sample_data.get("generated_ids") + orig_ids = sample_data.get("orig_ids") + masked_positions = set(sample_data.get("masked_positions") or []) + if isinstance(gen_ids, list) and isinstance(orig_ids, list): + styles: list[str] = [] + for i, tid in enumerate(gen_ids): + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + styles.append("green") + elif i < len(orig_ids): + styles.append("red") + else: + styles.append("normal") + else: + same = i < len(orig_ids) and tid == orig_ids[i] + styles.append("dim" if same else "normal") + 
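+                    # Group adjacent tokens that share a style into contiguous
+                    # spans (same grouping approach as the CLI rendering helpers)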
+ spans: list[tuple[str, int, int]] = [] + if gen_ids: + cur = styles[0] + start = 0 + for i in range(1, len(gen_ids)): + s = styles[i] + if s != cur: + spans.append((cur, start, i)) + cur, start = s, i + spans.append((cur, start, len(gen_ids))) + + parts = [] + for style_name, a, b in spans: + chunk_text = self.trainer.processing_class.decode( + gen_ids[a:b], skip_special_tokens=False + ) + if style_name == "green": + parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL) + elif style_name == "red": + parts.append(Fore.RED + chunk_text + Style.RESET_ALL) + else: + if style_name == "dim": + parts.append(Style.DIM + chunk_text + Style.RESET_ALL) + else: + parts.append(chunk_text) + logger.info("\tGenerated:\n%s", "".join(parts)) + else: + logger.info(f"\tGenerated: {generated}") + except Exception: + logger.info(f"\tGenerated: {generated}") + + logger.info("=" * 60) + + if self.trainer.cfg.use_wandb: + if wandb.run is not None: + wandb.log( + { + "generated_samples": wandb.Table( + columns=[ + "step", + "original", + "masked", + "generated", + "mask_ratio", + "masked_tokens", + "total_tokens", + ], + data=[ + [ + step, + sample["original"], + sample["masked"], + sample["generated"], + f"{sample['mask_ratio']:.1%}", + sample["masked_tokens"], + sample["total_tokens"], + ] + for sample in samples + ], + ) + }, + step=step, + ) diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py new file mode 100644 index 000000000..49e3cdfae --- /dev/null +++ b/src/axolotl/integrations/diffusion/generation.py @@ -0,0 +1,409 @@ +"""Sample generation utilities for diffusion training.""" + +import re +from typing import Any, List, Literal, Optional + +import torch + +from axolotl.utils.logging import get_logger + +from .utils import create_bidirectional_attention_mask + +LOG = get_logger(__name__) + + +def generate_samples( + model: torch.nn.Module, + tokenizer: Any, + dataloader: Optional[Any] = None, + num_generation_samples: int = 3, + max_length: int = 100, + num_diffusion_steps: int = 128, + temperature: float = 0.0, + mask_token_id: int = 32000, + mode: Literal["random", "completion"] = "random", + completion_tokens: int = 0, + target_mask_ratio: Optional[float] = None, +) -> List[dict]: + """ + Generate text samples using the diffusion model by randomly masking sequences from + the given dataset and running the reverse diffusion process. 
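+
+    Example (an illustrative sketch, assuming a `model`, `tokenizer`, and
+    `eval_dataloader` already exist in scope)::
+
+        samples = generate_samples(
+            model=model,
+            tokenizer=tokenizer,
+            dataloader=eval_dataloader,
+            num_generation_samples=2,
+            max_length=64,
+            num_diffusion_steps=32,
+        )
+        for sample in samples:
+            print(sample["formatted"])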
+ + Args: + model: The wrapped or unwrapped model + tokenizer: Tokenizer for encoding/decoding + dataloader: Validation dataloader (for sampling sequences) + num_generation_samples: Number of samples to generate + max_length: Maximum length of sequences to use + num_diffusion_steps: Number of diffusion steps for generation + temperature: Temperature for sampling (0.0 = deterministic) + mask_token_id: Token ID used for masking + + Returns: + List of dictionaries with original text, masked text, and generated text + """ + if dataloader is None: + LOG.warning("No validation dataloader provided, cannot generate samples") + return [] + + unwrapped_model = model.module if hasattr(model, "module") else model + training = unwrapped_model.training + unwrapped_model.eval() + + # Resolve device robustly (some modules don't expose `.device`) + device = getattr(unwrapped_model, "device", None) + if device is None: + try: + device = next(unwrapped_model.parameters()).device + except StopIteration: + device = torch.device("cpu") + generations = [] + + # Sample sequences from validation dataset + sampled_sequences = _sample_sequences_from_dataloader( + dataloader, num_generation_samples, max_length, device + ) + LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset") + + # Generate samples using reverse diffusion process + with torch.no_grad(): + for sample in sampled_sequences: + if isinstance(sample, dict): + original_sequence = sample.get("input_ids") + labels_seq = sample.get("labels") + attn_seq = sample.get("attention_mask") + else: + original_sequence = sample + labels_seq = None + attn_seq = None + generation_result = generate( + unwrapped_model, + tokenizer, + original_sequence, + num_diffusion_steps, + temperature, + mask_token_id, + mode=mode, + completion_tokens=completion_tokens, + target_mask_ratio=target_mask_ratio, + labels=labels_seq, + attention_mask=attn_seq, + ) + generations.append(generation_result) + + # Restore prior training state + if training: + unwrapped_model.train() + else: + unwrapped_model.eval() + + return generations + + +def _sample_sequences_from_dataloader( + dataloader: Any, num_samples: int, max_length: int, device: torch.device +) -> List[Any]: + """Sample sequences from validation dataloader.""" + sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = [] + sample_count = 0 + + # Skip a random number of batches (we could be more clever about this) + skip_batches = torch.randint(0, 10, (1,)).item() + batch_count = 0 + + for batch in dataloader: + # Skip some batches for variety + if batch_count < skip_batches: + batch_count += 1 + continue + + if sample_count >= num_samples: + break + + batch_count += 1 + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask") + labels = batch.get("labels") + + # Randomly sample from sequences in this batch + batch_indices = torch.randperm(input_ids.size(0)).tolist() + + for i in batch_indices: + if sample_count >= num_samples: + break + + # Get actual sequence length (non-padded) + if attention_mask is not None: + seq_len = attention_mask[i].sum().item() + else: + seq_len = input_ids.size(1) + + if seq_len < 10: + continue + + # Determine truncation length + max_total = min(seq_len, max_length) + if labels is not None: + labels_i = labels[i][:seq_len] + answer_mask = labels_i != -100 + if not answer_mask.any(): + # No answer tokens; skip for SFT masking + continue + first_ans_idx = int( + torch.nonzero(answer_mask, as_tuple=False)[0].item() + ) + prompt_len = first_ans_idx + 
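+                # Truncation budget: keep the entire prompt, then take however
+                # many answer tokens still fit under max_total; if the prompt
+                # alone reaches the cap there is nothing left to generate, so
+                # the sample is skipped.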
if prompt_len >= max_total: + # Prompt alone reaches cap; cannot include any answer + continue + remaining_answer = int(answer_mask[prompt_len:].sum().item()) + allowed_answer = max_total - prompt_len + take_answer = min(remaining_answer, allowed_answer) + if take_answer <= 0: + continue + actual_length = prompt_len + take_answer + else: + actual_length = max_total + + # Extract the (possibly truncated) sequence + sequence = input_ids[i][:actual_length].unsqueeze(0).to(device) + attn_seq = ( + attention_mask[i][:actual_length].unsqueeze(0).to(device) + if attention_mask is not None + else None + ) + if labels is not None: + labels_seq = labels[i][:actual_length].unsqueeze(0).to(device) + sampled_sequences.append( + { + "input_ids": sequence, + "labels": labels_seq, + "attention_mask": attn_seq, + } + ) + else: + if attn_seq is not None: + sampled_sequences.append( + {"input_ids": sequence, "attention_mask": attn_seq} + ) + else: + sampled_sequences.append(sequence) + sample_count += 1 + + return sampled_sequences + + +def generate( + model: torch.nn.Module, + tokenizer: Any, + original_sequence: torch.Tensor, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, + *, + mode: Literal["random", "completion"] = "random", + completion_tokens: int = 0, + target_mask_ratio: Optional[float] = None, + labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> dict: + """Generate a single sample using reverse diffusion.""" + # Get original text for comparison + original_text = tokenizer.decode( + original_sequence[0].cpu(), skip_special_tokens=True + ) + + # Build masked sequence + if ( + labels is not None + and labels.numel() > 0 + and (labels == -100).any() + and (labels != -100).any() + ): + # SFT case: completely mask all answer tokens (labels != -100) + total_tokens = original_sequence.size(1) + masked_indices = (labels != -100).to(dtype=torch.bool) + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + masked_tokens = int(masked_indices.sum().item()) + mask_ratio = masked_tokens / max(int(total_tokens), 1) + elif mode == "completion" and completion_tokens > 0: + # Append mask tokens to the right for completion + total_tokens = original_sequence.size(1) + int(completion_tokens) + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, -int(completion_tokens) :] = True + + append = torch.full( + (1, int(completion_tokens)), mask_token_id, device=original_sequence.device + ) + masked_sequence = torch.cat([original_sequence, append], dim=1) + masked_tokens = int(completion_tokens) + mask_ratio = masked_tokens / total_tokens + else: + # Apply random masking with optional fixed ratio + total_tokens = original_sequence.size(1) + if target_mask_ratio is None: + min_ratio, max_ratio = 0.1, 0.7 + target_mask_ratio = ( + torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio + ) + target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio))) + + # Create random mask indices + mask_positions = torch.randperm(total_tokens)[:target_masked_tokens] + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, mask_positions] = True + + # Create masked sequence + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + + # Calculate actual mask ratio + masked_tokens = masked_indices.sum().item() + mask_ratio = masked_tokens / 
total_tokens + + # Get masked text for comparison + masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False) + masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id) + + # Run reverse diffusion process + sequence = masked_sequence.clone() + attention_mask = create_bidirectional_attention_mask( + sequence, attention_mask, sample_packing=attention_mask is not None + ) + for step in range(num_diffusion_steps): + sequence = _diffusion_step( + model, + sequence, + step, + num_diffusion_steps, + temperature, + mask_token_id, + attention_mask, + ) + generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True) + + # Collect diagnostic info + final_ids = sequence[0].detach().cpu().tolist() + orig_ids_for_render = original_sequence[0].detach().cpu().tolist() + if masked_indices is not None: + masked_positions = ( + torch.where(masked_indices[0])[0].detach().cpu().tolist() + if masked_indices.ndim == 2 + else [] + ) + else: + masked_positions = [] + + result = { + "original": original_text, + "masked": masked_text, + "generated": generated_text, + "mask_ratio": mask_ratio, + "masked_tokens": masked_tokens, + "total_tokens": total_tokens, + "generated_ids": final_ids, + "masked_positions": masked_positions, + "orig_ids": orig_ids_for_render, + "formatted": ( + f"Original: '{original_text}' → Masked: '{masked_text}' " + f"({mask_ratio:.1%}) → Generated: '{generated_text}'" + ), + } + + return result + + +def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str: + """Clean up masked text for display.""" + mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False) + cleaned = masked_text.replace(mask_token_repr, "[MASK]") + + # Remove literal special token strings + if hasattr(tokenizer, "special_tokens_map"): + for token_value in tokenizer.special_tokens_map.values(): + if token_value and isinstance(token_value, str): + cleaned = cleaned.replace(token_value, "") + + # Normalize whitespace but preserve newlines + cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n") + cleaned = re.sub(r"[ \t]+", " ", cleaned) + cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip() + return cleaned + + +def _diffusion_step( + model: torch.nn.Module, + sequence: torch.Tensor, + step: int, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Perform a single diffusion step with remasking.""" + # Only process if there are masked tokens remaining + current_mask = sequence == mask_token_id + if not current_mask.any(): + return sequence + + # Create or use provided attention mask + if attention_mask is None: + batch_size, seq_len = sequence.shape + attention_mask = torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device + ) + + # Forward pass + outputs = model(input_ids=sequence, attention_mask=attention_mask) + logits = outputs.logits + + # Only sample at currently masked positions + if current_mask.any(): + masked_logits = logits[current_mask] + + # Apply temperature scaling + if temperature > 0: + scaled_logits = masked_logits / temperature + else: + scaled_logits = masked_logits + + # Suppress mask token in outputs + scaled_logits[:, mask_token_id] = -float("inf") + + if temperature > 0: + # Add Gumbel noise for sampling + gumbel_noise = -torch.log( + -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32)) + ) + gumbel_logits = scaled_logits + gumbel_noise + 
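+            # Gumbel-max trick: adding i.i.d. Gumbel(0, 1) noise to the scaled
+            # logits and taking the argmax draws an exact sample from
+            # softmax(scaled_logits), i.e. temperature sampling without an
+            # explicit torch.multinomial call.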
predicted_tokens = torch.argmax(gumbel_logits, dim=-1) + else: + predicted_tokens = torch.argmax(scaled_logits, dim=-1) + + # Calculate probabilities for confidence scoring + probs = torch.softmax(scaled_logits, dim=-1) + predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens] + + # Determine how many tokens to unmask this step + remaining_masked = current_mask.sum().item() + if step == num_diffusion_steps - 1: + num_to_unmask = remaining_masked + else: + unmask_ratio = 1.0 / (num_diffusion_steps - step) + num_to_unmask = max(1, int(remaining_masked * unmask_ratio)) + + # Select highest confidence predictions to unmask + if num_to_unmask >= remaining_masked: + sequence[current_mask] = predicted_tokens + else: + _, top_indices = predicted_token_probs.topk(num_to_unmask) + mask_positions = torch.where(current_mask)[1] + positions_to_unmask = mask_positions[top_indices] + sequence[0, positions_to_unmask] = predicted_tokens[top_indices] + + return sequence diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py new file mode 100644 index 000000000..c31f48b03 --- /dev/null +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -0,0 +1,41 @@ +"""Diffusion LM training plugin for Axolotl.""" + +from peft import PeftModel +from transformers import PreTrainedModel + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .trainer import DiffusionTrainer + +LOG = get_logger(__name__) + + +class DiffusionPlugin(BasePlugin): + """ + Plugin for diffusion language model training. + + This plugin enables diffusion-based training using the LLaDA approach, which uses + random masking and bidirectional attention to train language models. 
+ """ + + def __init__(self): + super().__init__() + self.cfg = None + + def get_input_args(self) -> str: + """Returns the pydantic model for LLaDA plugin arguments.""" + return "axolotl.integrations.diffusion.DiffusionArgs" + + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Perform actions after model is loaded.""" + self.cfg = cfg + + def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: + """Return custom trainer class for diffusion training.""" + return DiffusionTrainer + + def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): + """Configure trainer after creation.""" + trainer.set_config(cfg) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py new file mode 100644 index 000000000..42b2468f4 --- /dev/null +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -0,0 +1,301 @@ +"""Custom trainer for diffusion LM training.""" + +from typing import Any, Literal + +import torch +import torch.nn.functional as F +from torch import nn + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .callbacks import DiffusionGenerationCallback +from .utils import create_bidirectional_attention_mask + +LOG = get_logger(__name__) + + +class DiffusionTrainer(AxolotlTrainer): + """Custom trainer for diffusion LM training that overrides loss computation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cfg = None + self._special_token_ids = None + + def set_config(self, config: DictDefault): + """Set config for diffusion training.""" + self.cfg = config + self._cache_special_token_ids() + self._resolve_mask_token_id() + + token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0)) + LOG.info(f"Diffusion: using mask_token_id={token_id}") + + if getattr(config.diffusion, "generate_samples", True): + generation_callback = DiffusionGenerationCallback(self) + self.add_callback(generation_callback) + + def _resolve_mask_token_id(self) -> None: + """Ensure mask_token_id is valid for the current tokenizer.""" + from .utils import resolve_mask_token_id + + tokenizer = getattr(self, "processing_class", None) + if tokenizer is None: + return + + mid = resolve_mask_token_id( + tokenizer, + self.cfg, + allow_add=True, + model=getattr(self, "model", None), + ) + try: + self.cfg.diffusion.mask_token_id = int(mid) + except Exception: + pass + + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, torch.Tensor], + return_outputs: bool = False, + num_items_in_batch: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Override compute_loss to use diffusion loss.""" + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + labels = inputs.get("labels") + + if input_ids is None: + raise ValueError("input_ids is required for diffusion training") + + loss, outputs = self._compute_diffusion_loss( + model, input_ids, attention_mask, labels + ) + + if return_outputs: + return loss, outputs + return loss + + def _cache_special_token_ids(self): + """Cache special token IDs to avoid repeated tokenizer access.""" + if self.processing_class is None: + self._special_token_ids = set() + return + + tokenizer = self.processing_class + special_tokens = set() + + if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None: + 
special_tokens.add(tokenizer.bos_token_id) + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + special_tokens.add(tokenizer.eos_token_id) + if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + special_tokens.add(tokenizer.pad_token_id) + + self._special_token_ids = special_tokens + + def _forward_process( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + eps: float = 1e-3, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward noising process. A timestep is sampled along the process, and tokens are + masked with probability determined by the configured noise schedule. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + eps: Small epsilon value for minimum masking probability. + + Returns: + noisy_batch: Input with some tokens masked. + masked_indices: Boolean mask indicating which tokens were masked. + p_mask: Masking probabilities for each token [batch_size, seq_len]. + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Sample random timesteps for each sample in batch + t = torch.rand(batch_size, device=device) + p_mask = (1 - eps) * t + eps # [batch_size] + p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] + + # Don't mask padding tokens if attention_mask is provided + if attention_mask is not None: + valid_mask = attention_mask.bool() + p_mask = p_mask * valid_mask.float() + + # Create mask to exclude special tokens + special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) + if self._special_token_ids: + for token_id in self._special_token_ids: + special_token_mask |= input_ids == token_id + + # Create random mask based on p_mask + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + masked_indices = masked_indices & ~special_token_mask + if attention_mask is not None: + masked_indices = masked_indices & attention_mask.bool() + + # For SFT data, only mask answer tokens + if labels is not None: + answer_mask = labels != -100 + masked_indices = masked_indices & answer_mask + + # Create masked input + mask_token_id = int(self.cfg.diffusion.mask_token_id) + mask_value = torch.full_like(input_ids, mask_token_id) + noisy_batch = torch.where(masked_indices, mask_value, input_ids) + + return noisy_batch, masked_indices, p_mask + + def _compute_diffusion_loss( + self, + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | Any]: + """ + Compute diffusion loss. + + Args: + model: The model to compute loss for. + input_ids: Ground truth token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + + Returns: + loss: Cross-entropy loss. + metrics: Dictionary of metrics. + """ + # Short-circuit empty sequences + if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0: + zero = torch.tensor( + 0.0, + device=(input_ids.device if input_ids is not None else None), + requires_grad=True, + ) + return zero, {} + + # If an attention_mask is provided and all positions are padding for every + # sample in this batch, skip the step. 
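+        # (With sample packing, the attention mask stores per-sample ids
+        # 1, 2, ... rather than 0/1 flags, but a row that sums to zero still
+        # means every position is padding, so the same check covers both
+        # layouts.)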
+ if attention_mask is not None: + if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all(): + zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + return zero, {} + + # Apply forward process + noisy_batch, masked_indices, p_mask = self._forward_process( + input_ids, attention_mask, labels, self.cfg.diffusion.eps + ) + + # Create bidirectional attention mask + bidirectional_mask = create_bidirectional_attention_mask( + input_ids, attention_mask, sample_packing=self.cfg.sample_packing + ) + + # Forward pass + outputs = model( + input_ids=noisy_batch.long(), + attention_mask=bidirectional_mask, + ) + logits = outputs.logits + + if masked_indices.sum() > 0: + valid_indices = torch.where(masked_indices) + batch_indices, seq_indices = valid_indices + + masked_logits = logits[batch_indices, seq_indices] + masked_targets = input_ids[batch_indices, seq_indices] + masked_p_mask = p_mask[batch_indices, seq_indices] + + # Compute cross-entropy loss without reduction + token_loss = F.cross_entropy( + masked_logits.float(), masked_targets, reduction="none" + ) + + if self.cfg.diffusion.importance_weighting: + masked_p_mask = masked_p_mask.float() + weighted_loss = token_loss / masked_p_mask + else: + weighted_loss = token_loss + + if labels is not None: + # For SFT data: normalize by answer token count per sample + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] + + # Get batch indices for masked tokens + masked_batch_indices = batch_indices + + # Sum losses per sample and divide by answer length + batch_size = input_ids.shape[0] + loss_per_sample = torch.zeros(batch_size, device=input_ids.device) + for i in range(batch_size): + sample_mask = masked_batch_indices == i + if sample_mask.sum() > 0: + sample_loss = weighted_loss[sample_mask].sum() + denom = answer_lengths[i].clamp(min=1.0) + loss_per_sample[i] = sample_loss / denom + + loss = loss_per_sample.mean() + else: + # Non-SFT: when importance weighting is enabled, use unbiased estimator + # (sum(loss/p) / total_tokens). Otherwise, average over masked tokens + # for stable scaling across varying mask ratios. 
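+                # Why this is unbiased: each token is masked independently
+                # with probability p, so E[1{masked} * loss / p] equals the
+                # token's loss; dividing the reweighted sum by
+                # batch_size * seq_len therefore estimates the mean loss over
+                # all tokens, not just the masked subset.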
+ if self.cfg.diffusion.importance_weighting: + loss = weighted_loss.sum() / ( + input_ids.shape[0] * input_ids.shape[1] + ) + else: + loss = weighted_loss.mean() + + ce_loss = token_loss.mean() + + # Compute accuracy on masked tokens + with torch.no_grad(): + pred_tokens = masked_logits.argmax(dim=-1) + accuracy = (pred_tokens == masked_targets).float().mean() + else: + loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + accuracy = torch.tensor(0.0, device=input_ids.device) + ce_loss = torch.tensor(0.0, device=input_ids.device) + masked_p_mask = torch.tensor(1.0, device=input_ids.device) + + avg_p_mask = ( + p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0 + ) + metrics = { + "loss": loss.item(), + "accuracy": accuracy.item(), + "mask_ratio": masked_indices.float().mean().item(), + "num_masked_tokens": (masked_indices.sum().item(), "sum"), + "avg_p_mask": avg_p_mask, + "ce_loss": ce_loss.item(), + } + + # If doing SFT training, log answer-specific metrics + if self.cfg.datasets is not None: + with torch.no_grad(): + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # type: ignore + total_answer_tokens = answer_mask.sum().item() # type: ignore + total_tokens = labels.numel() # type: ignore + metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1) + metrics["avg_answer_length"] = answer_lengths.mean().item() + + if self.cfg.diffusion.importance_weighting: + metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() + + train_eval: Literal["train", "eval"] = "train" if model.training else "eval" + self.store_metrics(metrics, train_eval=train_eval) + + return loss, outputs diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py new file mode 100644 index 000000000..47abf6fec --- /dev/null +++ b/src/axolotl/integrations/diffusion/utils.py @@ -0,0 +1,159 @@ +"""Shared utilities for diffusion integration.""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from axolotl.utils.dict import DictDefault + + +def resolve_mask_token_id( + tokenizer: Any, + cfg: DictDefault, + *, + allow_add: bool, + model: Any | None = None, + default_token: str = "<|diffusion_mask|>", +) -> int: + """Resolve mask token id. 
Training may add a new special token; inference won't.""" + # Determine vocab size if available + vocab_size = None + if tokenizer is not None: + if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None: + try: + vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type] + except Exception: + vocab_size = None + elif hasattr(tokenizer, "__len__"): + try: + vocab_size = int(len(tokenizer)) + except Exception: + vocab_size = None + + # Use explicit id from config if provided + diffusion_cfg = getattr(cfg, "diffusion", None) + # Fallback to top-level attr names only if nested missing (shouldn't happen) + cfg_id = ( + getattr(diffusion_cfg, "mask_token_id", None) + if diffusion_cfg is not None + else getattr(cfg, "diffusion_mask_token_id", None) + ) + if isinstance(cfg_id, int) and cfg_id >= 0: + if vocab_size is None or cfg_id < vocab_size: + return int(cfg_id) + + def _existing_special_token_id(token_str: str | None) -> int | None: + """Attempt to resolve an existing special token string to a real ID.""" + if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"): + return None + try: + token_id = tokenizer.convert_tokens_to_ids(token_str) + except Exception: + return None + + if not isinstance(token_id, int) or token_id < 0: + return None + + # Ensure it's registered as special and not UNK, and within vocab + unk_id = getattr(tokenizer, "unk_token_id", None) + specials = set(getattr(tokenizer, "all_special_tokens", []) or []) + addl = set(getattr(tokenizer, "additional_special_tokens", []) or []) + is_special = token_str in specials or token_str in addl + in_vocab = vocab_size is None or token_id < vocab_size + if ( + (unk_id is not None and token_id == unk_id) + or not is_special + or not in_vocab + ): + return None + return token_id + + # Try mask token string if provided + token_str = ( + getattr(diffusion_cfg, "mask_token_str", None) + if diffusion_cfg is not None + else getattr(cfg, "diffusion_mask_token_str", None) + ) + for candidate in (token_str, default_token): + token_id = _existing_special_token_id(candidate) + if isinstance(token_id, int): + try: + if diffusion_cfg is None: + cfg.diffusion_mask_token_id = int(token_id) # legacy fallback + else: + diffusion_cfg.mask_token_id = int(token_id) + except Exception: + pass + return int(token_id) + + # Optionally add and return a dedicated special token during training + if allow_add and hasattr(tokenizer, "add_special_tokens"): + token_to_add = token_str or default_token + try: + tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]}) + + # Resize embeddings if possible + if ( + model is not None + and hasattr(tokenizer, "__len__") + and hasattr(model, "resize_token_embeddings") + ): + try: + model.resize_token_embeddings(len(tokenizer)) + except Exception: + pass + new_id = tokenizer.convert_tokens_to_ids(token_to_add) + if isinstance(new_id, int) and new_id >= 0: + try: + if diffusion_cfg is None: + cfg.diffusion_mask_token_id = int(new_id) # legacy fallback + else: + diffusion_cfg.mask_token_id = int(new_id) + except Exception: + pass + return int(new_id) + except Exception: + pass + + # Fallback to unk or 0 (do not update cfg) + fallback = getattr(tokenizer, "unk_token_id", 0) or 0 + return int(fallback) + + +def create_bidirectional_attention_mask( + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sample_packing: bool = False, +) -> torch.Tensor: + """ + Create bidirectional attention mask to override default causal masking. 
+ Handles sample-packed sequences where different samples are identified + by different attention mask values. + + Args: + input_ids: Input token ids [batch_size, seq_len] + attention_mask: Attention mask [batch_size, seq_len] + sample_packing: Whether sample packing is enabled + + Returns: + bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len] + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + if attention_mask is None or not sample_packing: + return torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device + ) + + # Handle sample packing: tokens can only attend within their sample + mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] + mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] + + # Tokens can attend to each other if they have the same non-zero sample ID + bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) + + # Add head dimension: [batch_size, 1, seq_len, seq_len] + return bidirectional_mask.unsqueeze(1) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 989b34aee..bcde4bf96 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -14,6 +14,7 @@ from peft import ( PeftConfig, PeftMixedModel, PeftModel, + TaskType, get_peft_model, ) from transformers import PreTrainedModel @@ -101,6 +102,15 @@ def load_lora( if cfg.peft_trainable_token_indices: lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices + # Determine the correct PEFT task type + model_cls = type(model).__name__ + if "SequenceClassification" in model_cls: + task_type = TaskType.SEQ_CLS + elif "TokenClassification" in model_cls: + task_type = TaskType.TOKEN_CLS + else: + task_type = TaskType.CAUSAL_LM + lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, @@ -112,7 +122,7 @@ def load_lora( fan_in_fan_out=cfg.lora_fan_in_fan_out, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, bias="none", - task_type="CAUSAL_LM", + task_type=task_type, **lora_config_kwargs, ) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index a9507d685..f438d6b61 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -673,6 +673,33 @@ class ModelLoader: return hf_ds_cfg + def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel: + """ + Load model with random initialization using from_config. + + Uses the selected loader when provided; otherwise falls back to the auto loader. 
+ """ + loader = model_loader_class or self.auto_model_loader + if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]: + model = loader.from_config( + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + else: + model = loader(config=self.model_config) + + return model + + def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel: + """Load model from pretrained weights.""" + loader = model_loader_class or self.auto_model_loader + kwargs = { + "config": self.model_config, + "trust_remote_code": self.cfg.trust_remote_code or False, + **self.model_kwargs, + } + return loader.from_pretrained(self.base_model, **kwargs) + def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" skip_move_to_device = False @@ -687,7 +714,8 @@ class ModelLoader: if self.is_fsdp_enabled: if self.cfg.fsdp_config.cpu_ram_efficient_loading: skip_move_to_device = True - # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map + # Don't delete device_map for QLoRA + FSDP - it was set correctly in + # _set_device_map if ( "device_map" in self.model_kwargs and not self.is_qlora_and_fsdp_enabled @@ -716,6 +744,11 @@ class ModelLoader: or self.cfg.qlora_sharded_model_loading ) ): + if self.cfg.reinit_weights: + LOG.warning( + "reinit_weights is not supported with sharded quantized loading. " + "Loading from pretrained weights instead." + ) quant_storage = self.cfg.torch_dtype quantization_config = getattr( self.model_config, "quantization_config", None @@ -731,33 +764,12 @@ class ModelLoader: quantization_config=quantization_config, ) skip_move_to_device = True - elif ( - self.model_config.model_type in ["llama", "llama4"] - and not self.cfg.trust_remote_code - and not self.cfg.gptq - ): - # Please don't remove underscore binding without reading the fn docstring. - _ = self._configure_zero3_memory_efficient_loading() - - # Load model with random initialization if specified - if self.cfg.random_init_weights: - # AutoModel classes support the from_config method - if self.auto_model_loader in [ - AutoModelForCausalLM, - AutoModelForVision2Seq, - ]: - self.model = self.auto_model_loader.from_config( - config=self.model_config, - ) - else: - self.model = self.auto_model_loader(config=self.model_config) - else: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) elif self.model_type == "MambaLMHeadModel": + if self.cfg.reinit_weights: + LOG.warning( + "reinit_weights is not supported with MambaLMHeadModel. " + "Loading from pretrained weights instead." 
+ ) # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() @@ -770,41 +782,27 @@ class ModelLoader: self.base_model, **self.model_kwargs, ) - elif ( - self.model_type - and self.model_type != "AutoModelForCausalLM" - and not self.cfg.trust_remote_code - ): - if self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - self.model = getattr(transformers, self.model_type).from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - elif self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) else: - # Please don't remove underscore binding without reading the fn docstring. + # Please don't remove underscore binding without reading the fn docstring _ = self._configure_zero3_memory_efficient_loading() - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) + + if ( + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code + and not self.cfg.gptq + ): + # Use model type from transformers + model_loader_class = getattr(transformers, self.model_type) + else: + # Use auto model loader (handles gptq and default cases) + model_loader_class = self.auto_model_loader + + if self.cfg.reinit_weights: + self.model = self._load_model_from_config(model_loader_class) + else: + self.model = self._load_model_from_pretrained(model_loader_class) + if is_deepspeed_zero3_enabled(): skip_move_to_device = True diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 044c278a3..a5a630cb5 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -3,8 +3,8 @@ Applies pre- and post-model load patches for various fixes and optimizations. 
""" -import os import importlib.util +import os from functools import cached_property import addict @@ -468,9 +468,10 @@ class PatchManager: def _apply_patch_deepspeed_zero3(self): try: - from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches + if self.cfg.activation_offloading is True and ( is_deepspeed_zero3_enabled() or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3" diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 3b38a33b7..d8ba02cb2 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -160,9 +160,11 @@ def get_state_dict(self, model, unwrap=True): state_dict[param_name] = param.cpu() torch.distributed.barrier() elif self.distributed_type == DistributedType.FSDP: - from torch.distributed.fsdp import FullStateDictConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import StateDictType + from torch.distributed.fsdp import ( + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, + ) full_state_dict_config = FullStateDictConfig( offload_to_cpu=True, rank0_only=True diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 65ccad533..678f65bee 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -1,11 +1,12 @@ """Flex attention monkey patch""" import sys -from packaging import version import torch import transformers +from packaging import version from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal + from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/deepspeed_utils.py b/src/axolotl/monkeypatch/deepspeed_utils.py index 6740f556b..d7e69e112 100644 --- a/src/axolotl/monkeypatch/deepspeed_utils.py +++ b/src/axolotl/monkeypatch/deepspeed_utils.py @@ -1,5 +1,6 @@ import importlib import importlib.util + from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index f40fe6687..7a2bbd6f9 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -17,8 +17,8 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, + AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset LOG = get_logger(__name__) diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 788f13638..8b9e4e91d 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,14 +1,14 @@ """Init for `axolotl.utils.data` module.""" -from axolotl.utils.data.streaming import ( - encode_streaming, - wrap_streaming_dataset, -) from axolotl.utils.data.rl import prepare_preference_datasets from axolotl.utils.data.sft import ( get_dataset_wrapper, prepare_datasets, ) +from axolotl.utils.data.streaming import ( + encode_streaming, + wrap_streaming_dataset, +) from axolotl.utils.data.utils 
import md5 __all__ = [ diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 28732e01d..ba5aec2d6 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -16,7 +16,6 @@ from transformers import PreTrainedTokenizer, ProcessorMixin from axolotl.prompters import Prompter from axolotl.utils.data.lock import FileLockLoader -from axolotl.utils.data.streaming import wrap_streaming_dataset from axolotl.utils.data.shared import ( create_train_validation_split, datasets_with_name_generator, @@ -27,6 +26,7 @@ from axolotl.utils.data.shared import ( save_preprocessed_dataset, try_load_from_hub, ) +from axolotl.utils.data.streaming import wrap_streaming_dataset from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, handle_long_seq_in_dataset, diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 751f7e253..192aca4e1 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -6,8 +6,6 @@ from importlib.metadata import version from accelerate.utils.environment import ( check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, -) -from accelerate.utils.environment import ( get_gpu_info, ) from packaging.version import Version, parse diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e4c1fdf29..d612ec8a5 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -106,6 +106,12 @@ class AxolotlInputConfig( "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs" }, ) + reinit_weights: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Reinitialize model weights randomly instead of loading pretrained weights" + }, + ) trainer_cls: str | None = Field( default=None, diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 49add8081..64018ca48 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -14,7 +14,6 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType - LOG = get_logger(__name__) SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py new file mode 100644 index 000000000..cc3d8070b --- /dev/null +++ b/tests/e2e/test_diffusion.py @@ -0,0 +1,139 @@ +"""E2E smoke test for diffusion training plugin.""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists + + +class TestDiffusion: + """Test case for diffusion training plugin.""" + + def test_diffusion_smoke_test(self, temp_dir): + """ + Smoke test for diffusion training to ensure the plugin loads and trains without + error. 
+ """ + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 256, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + "logging_steps": 1, + "eval_steps": 3, + # Diffusion-specific config + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion": { + # sample generation + "generate_samples": True, + "generation_interval": 1, + "num_generation_samples": 1, + "generation_steps": 2, + "generation_max_length": 32, + "generation_temperature": 0.0, + # training-specific + "mask_token_id": 16, + "eps": 1e-3, + "importance_weighting": False, + }, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + def test_diffusion_sft_labels(self, temp_dir): + """Test that diffusion training properly handles SFT data with labels.""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 256, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + "logging_steps": 1, + "eval_steps": 2, + # Diffusion-specific config + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion": { + # sample generation + "generate_samples": True, + "generation_interval": 1, + "num_generation_samples": 1, + "generation_steps": 2, + "generation_max_length": 32, + "generation_temperature": 0.0, + # training-specific + "mask_token_id": 16, + "eps": 1e-3, + "importance_weighting": True, + }, + # Ensure we have proper SFT labels + "train_on_inputs": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + # Verify that the dataset has labels + sample = dataset_meta.train_dataset[0] + assert "labels" in sample, "SFT dataset should have labels" + + # Check that some labels are -100 (prompt tokens) + labels = sample["labels"] + if hasattr(labels, "tolist"): + labels = labels.tolist() + assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens" + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py new file mode 100644 index 000000000..141d8d150 --- /dev/null +++ b/tests/integrations/test_diffusion.py @@ -0,0 +1,274 @@ +"""Tests for diffusion trainer integration.""" + +# pylint: disable=redefined-outer-name,protected-access + +from unittest.mock import Mock + +import pytest +import torch + +from axolotl.integrations.diffusion import DiffusionTrainer +from 
axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask +from axolotl.utils.dict import DictDefault + + +@pytest.fixture +def mock_tokenizer(): + """Create a mock tokenizer.""" + tokenizer = Mock() + tokenizer.bos_token_id = 1 + tokenizer.eos_token_id = 2 + tokenizer.pad_token_id = 0 + return tokenizer + + +@pytest.fixture +def diffusion_config(): + """Create a diffusion config.""" + return DictDefault( + { + "diffusion": { + "mask_token_id": 32000, + "eps": 1e-3, + "importance_weighting": False, + }, + "sample_packing": False, + } + ) + + +@pytest.fixture +def diffusion_trainer_instance(mock_tokenizer, diffusion_config): + """Create a diffusion trainer instance for testing methods directly.""" + # Create a minimal trainer instance just for testing methods + trainer = object.__new__(DiffusionTrainer) # Bypass __init__ + trainer.cfg = diffusion_config + trainer._special_token_ids = {0, 1, 2} # pad, bos, eos + trainer.processing_class = mock_tokenizer + trainer.store_metrics = Mock() # Mock metrics storage + return trainer + + +class TestDiffusionTrainer: + """Test the DiffusionTrainer class.""" + + def test_forward_process_basic(self, diffusion_trainer_instance): + """Test basic forward process without labels.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process(input_ids, eps=0.1) + ) + + # Check shapes + assert noisy_batch.shape == input_ids.shape + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + + # Check that special tokens are not masked + special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0) + assert not masked_indices[special_token_positions].any() + + # Check that mask token is applied + mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id + masked_positions = masked_indices + if masked_positions.any(): + assert (noisy_batch[masked_positions] == mask_token_id).all() + + def test_forward_process_with_labels(self, diffusion_trainer_instance): + """Test forward process with SFT labels.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process( + input_ids, labels=labels, eps=0.1 + ) + ) + + # Check shapes + assert noisy_batch.shape == input_ids.shape + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + + # Check that only answer tokens can be masked (where labels != -100) + non_answer_mask = labels == -100 + + # No masking should occur on non-answer tokens + assert not masked_indices[non_answer_mask].any() + + # p_mask should be the same for all positions (sampled timestep), + # but masking is only applied to answer tokens + assert p_mask.shape == input_ids.shape + # Verify that masked_indices respects the answer mask + assert not masked_indices[non_answer_mask].any() + + def test_forward_process_with_attention_mask(self, diffusion_trainer_instance): + """Test forward process with attention mask.""" + input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) + + _, masked_indices, p_mask = diffusion_trainer_instance._forward_process( + input_ids, attention_mask=attention_mask, eps=0.1 + ) + + # Check that padding tokens are not masked + padding_positions = attention_mask == 0 + assert not 
masked_indices[padding_positions].any() + assert (p_mask[padding_positions] == 0).all() + + def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): + """Test bidirectional attention mask without sample packing.""" + input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) + + mask = create_bidirectional_attention_mask(input_ids) + + # Should be all-to-all attention + expected_shape = (1, 1, 4, 4) + assert mask.shape == expected_shape + assert mask.all() + + def test_bidirectional_attention_mask_with_packing( + self, diffusion_trainer_instance + ): + """Test bidirectional attention mask with sample packing.""" + diffusion_trainer_instance.cfg.sample_packing = True + input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) + # Sample IDs: first sample (1), second sample (2) + attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) + + mask = create_bidirectional_attention_mask( + input_ids, attention_mask, sample_packing=True + ) + + # Check that tokens within same sample can attend to each other + # but not across samples + assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other + assert mask[0, 0, 1, 2].item() + assert not mask[0, 0, 0, 3].item() # Can't attend across samples + assert not mask[0, 0, 2, 4].item() + assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other + + def test_compute_loss_basic(self, diffusion_trainer_instance): + """Test basic loss computation.""" + # Mock model that returns logits + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 5 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + mock_model.return_value = mock_outputs + mock_model.training = True + + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + loss, outputs = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids + ) + + # Check that loss is computed + assert isinstance(loss, torch.Tensor) + assert loss.requires_grad + assert outputs == mock_outputs + + # Check that metrics were stored + diffusion_trainer_instance.store_metrics.assert_called_once() + + def test_compute_loss_sft(self, diffusion_trainer_instance): + """Test loss computation with SFT labels.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 5 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + mock_model.return_value = mock_outputs + mock_model.training = True + diffusion_trainer_instance.cfg.datasets = Mock() + + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids, labels=labels + ) + + # Check that loss is computed + assert isinstance(loss, torch.Tensor) + assert loss.requires_grad + + # Check that SFT metrics were added + call_args = diffusion_trainer_instance.store_metrics.call_args[0][0] + assert "answer_ratio" in call_args + assert "avg_answer_length" in call_args + + def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): + """Test loss computation when no tokens are masked.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 3 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size) + mock_model.return_value = mock_outputs + mock_model.training = True + + # Only special tokens (which won't be masked) + input_ids = 
torch.tensor([[1, 0, 2]], dtype=torch.long) + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids + ) + + # Loss should be zero when no tokens are masked + assert loss.item() == 0.0 + assert loss.requires_grad + + def test_cache_special_token_ids(self, mock_tokenizer): + """Test caching of special token IDs.""" + trainer = object.__new__(DiffusionTrainer) + trainer.processing_class = mock_tokenizer + trainer._cache_special_token_ids() + assert trainer._special_token_ids == {0, 1, 2} + + def test_cache_special_token_ids_no_tokenizer(self): + """Test caching when no tokenizer is available.""" + trainer = object.__new__(DiffusionTrainer) + trainer.processing_class = None + trainer._cache_special_token_ids() + + assert trainer._special_token_ids == set() + + def test_main_compute_loss_interface(self, diffusion_trainer_instance): + """Test the main compute_loss interface.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + mock_outputs.logits = torch.randn(1, 5, 1000) + mock_model.return_value = mock_outputs + mock_model.training = True + + inputs = { + "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), + "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), + } + + # Test without return_outputs + loss = diffusion_trainer_instance.compute_loss(mock_model, inputs) + assert isinstance(loss, torch.Tensor) + + # Test with return_outputs + loss, outputs = diffusion_trainer_instance.compute_loss( + mock_model, inputs, return_outputs=True + ) + assert isinstance(loss, torch.Tensor) + assert outputs == mock_outputs + + def test_missing_input_ids_raises_error(self, diffusion_trainer_instance): + """Test that missing input_ids raises ValueError.""" + mock_model = Mock() + inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} + + with pytest.raises(ValueError, match="input_ids is required"): + diffusion_trainer_instance.compute_loss(mock_model, inputs) diff --git a/tests/integrations/test_diffusion_callback.py b/tests/integrations/test_diffusion_callback.py new file mode 100644 index 000000000..3e8785fe0 --- /dev/null +++ b/tests/integrations/test_diffusion_callback.py @@ -0,0 +1,92 @@ +"""Tests for diffusion generation callback dataloader selection and triggering.""" + +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from axolotl.integrations.diffusion import DiffusionGenerationCallback + + +class DummyTrainer: + """Minimal trainer double with required attributes/methods for the callback.""" + + def __init__(self, use_eval: bool): + # Config used by callback + self.cfg = SimpleNamespace( + diffusion=SimpleNamespace( + generation_interval=1, + num_generation_samples=1, + generation_max_length=32, + generation_steps=4, + generation_temperature=0.0, + mask_token_id=16, + ), + use_wandb=False, + ) + + # Model/tokenizer are passed through to generate_samples; not used here + self.model = Mock() + self.processing_class = Mock() + + # Datasets and loaders + self.eval_dataset = object() if use_eval else None + self._train_loader = object() + self._eval_loader = object() + + # State for world process check + self.state = SimpleNamespace(is_world_process_zero=True) + + # Track which loader was requested + self.requested: list[str] = [] + + def get_train_dataloader(self): + self.requested.append("train") + return self._train_loader + + def get_eval_dataloader(self): + self.requested.append("eval") + return self._eval_loader + + 
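+# The parametrized test below covers both branches of the callback's
+# dataloader selection: the eval loader when an eval dataset exists, the
+# train loader otherwise.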
+@pytest.mark.parametrize("use_eval", [False, True]) +def test_callback_uses_correct_dataloader(monkeypatch, use_eval): + trainer = DummyTrainer(use_eval=use_eval) + callback = DiffusionGenerationCallback(trainer) + + captured = {} + + # Patch generate_samples in the callback module's namespace + def fake_generate_samples(**kwargs): + captured["dataloader"] = kwargs.get("dataloader") + # Return one dummy sample to exercise logging path + return [ + { + "original": "o", + "masked": "m", + "generated": "g", + "mask_ratio": 0.5, + "masked_tokens": 1, + "total_tokens": 2, + } + ] + + monkeypatch.setattr( + "axolotl.integrations.diffusion.callbacks.generate_samples", + fake_generate_samples, + ) + + # Trigger at step 1 (interval=1) + args = SimpleNamespace() + state = SimpleNamespace(global_step=1) + control = SimpleNamespace() + + callback.on_step_end(args=args, state=state, control=control) + + # Assert the expected dataloader path was used + if use_eval: + assert trainer.requested[0] == "eval" + assert captured["dataloader"] is trainer._eval_loader + else: + assert trainer.requested[0] == "train" + assert captured["dataloader"] is trainer._train_loader diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 54acbb5e4..2c1f9f936 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,12 +5,12 @@ from unittest.mock import Mock, patch from datasets import IterableDataset -from axolotl.utils.dict import DictDefault +from axolotl.utils.config import validate_config from axolotl.utils.data.sft import ( _prepare_streaming_dataset, prepare_datasets, ) -from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault class TestStreamingConfig(unittest.TestCase): From 9406c0c488277ef9d7152568b9fda50600c4221e Mon Sep 17 00:00:00 2001 From: salman Date: Thu, 11 Sep 2025 11:19:30 +0100 Subject: [PATCH 5/8] log before eval step (#3148) [skip-ci] --- src/axolotl/core/trainers/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 3427a0b86..627f8e3f8 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -371,6 +371,11 @@ class AxolotlTrainer( num_items_in_batch=num_items_in_batch, ) + @override + def evaluate(self, *args, **kwargs): + LOG.info("Running evaluation step...") + return super().evaluate(*args, **kwargs) + @staticmethod def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): concatenated_batch = {} From fcfc13d7106fe965054e46f0ad6b4f478cc5ba7c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 12 Sep 2025 14:45:18 +0700 Subject: [PATCH 6/8] feat(doc): update thinking and chat_template notes (#3114) [skip ci] * feat: update thinking and chat_template notes * fix: grammar --- examples/gpt-oss/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/gpt-oss/README.md b/examples/gpt-oss/README.md index 0aa04a71c..fb6c67498 100644 --- a/examples/gpt-oss/README.md +++ b/examples/gpt-oss/README.md @@ -106,6 +106,16 @@ See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-to Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info. +### Thinking and chat_template masking conflict + +OpenAI’s Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl’s `chat_template` masking. 
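+
+For illustration, a multi-turn sample shaped like the following (a sketch; it assumes the standard `messages` schema, with reasoning carried in a `thinking` field on assistant turns) has the `thinking` of its first assistant turn hidden by Harmony at render time, so that text never reaches the tokenized sequence that the masking logic labels:
+
+```json
+{"messages": [
+  {"role": "user", "content": "What is 2 + 2?"},
+  {"role": "assistant", "thinking": "Simple arithmetic.", "content": "4"},
+  {"role": "user", "content": "And 3 + 3?"},
+  {"role": "assistant", "thinking": "Same idea as before.", "content": "6"}
+]}
+```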
+ +If your dataset has `thinking` content mid-turn, there are two paths we recommend: + +- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message). + +- Adjust your dataset to only have `thinking` content in the last turn. + ### TIPS - Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). From 0401a15888e480d03c7cd0fb6439b27e0dacd3a0 Mon Sep 17 00:00:00 2001 From: salman Date: Fri, 12 Sep 2025 10:55:11 +0100 Subject: [PATCH 7/8] SEO go brrr (#3153) [skip-ci] --- CITATION.cff | 2 +- README.md | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index e6ecc7cb8..7bbfeec64 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,6 +1,6 @@ cff-version: 1.2.0 type: software -title: "Axolotl: Post-Training for AI Models" +title: "Axolotl: Open Source LLM Post-Training" message: "If you use this software, please cite it as below." authors: - name: "Axolotl maintainers and contributors" diff --git a/README.md b/README.md index d4794124a..1a033acd9 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,9 @@ Axolotl

+

+ A Free and Open Source LLM Fine-tuning Framework
+

GitHub License @@ -50,20 +53,21 @@ ## ✨ Overview -Axolotl is a tool designed to streamline post-training for various AI models. +Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs). Features: -- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models. -- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM). -- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference. +- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub. +- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support. +- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM). +- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference. - **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more! - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets. - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware. 
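+
+For example, the **Easy Configuration** workflow above boils down to reusing one file across every stage of the pipeline (`config.yml` is a placeholder for your own YAML config):
+
+```bash
+axolotl preprocess config.yml
+axolotl train config.yml
+axolotl evaluate config.yml
+axolotl quantize config.yml
+axolotl inference config.yml
+```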
-## 🚀 Quick Start +## 🚀 Quick Start - LLM Fine-tuning in Minutes **Requirements**: @@ -160,7 +164,7 @@ If you use Axolotl in your research or projects, please cite it as follows: ```bibtex @software{axolotl, - title = {Axolotl: Post-Training for AI Models}, + title = {Axolotl: Open Source LLM Post-Training}, author = {{Axolotl maintainers and contributors}}, url = {https://github.com/axolotl-ai-cloud/axolotl}, license = {Apache-2.0}, From 58d67bf98ddca63cb082374a04f8b2250ffc2c59 Mon Sep 17 00:00:00 2001 From: salman Date: Fri, 12 Sep 2025 10:55:50 +0100 Subject: [PATCH 8/8] Migrate QAT API; fix `axolotl quantize` for QAT-ed models; add NVFP4 (#3107) --- .github/workflows/multi-gpu-e2e.yml | 2 +- .github/workflows/tests.yml | 2 +- docs/quantize.qmd | 8 + examples/llama-3/3b-qat-fsdp2-nvfp4.yaml | 64 ++++ examples/llama-3/3b-qat-fsdp2.yaml | 18 +- requirements.txt | 2 +- setup.py | 1 + src/axolotl/cli/args.py | 1 + src/axolotl/cli/quantize.py | 50 ++- src/axolotl/train.py | 19 +- src/axolotl/utils/quantization.py | 244 +++++++------- src/axolotl/utils/schemas/enums.py | 25 +- src/axolotl/utils/schemas/quantization.py | 54 ++-- tests/e2e/test_qat.py | 4 +- tests/e2e/test_quantization.py | 369 ++++++++++++---------- tests/e2e/utils.py | 30 ++ 16 files changed, 554 insertions(+), 339 deletions(-) create mode 100644 examples/llama-3/3b-qat-fsdp2-nvfp4.yaml diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 6492e5d3e..05f9e0761 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -44,7 +44,7 @@ jobs: cuda_version: 12.8.1 python_version: "3.11" pytorch: 2.8.0 - axolotl_extras: + axolotl_extras: fbgemm-gpu num_gpus: 2 nightly_build: "true" runs-on: [self-hosted, modal] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 337230d4a..cfd2c715d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -304,7 +304,7 @@ jobs: pytorch: 2.8.0 num_gpus: 1 gpu_type: "B200" - axolotl_extras: + axolotl_extras: fbgemm-gpu steps: - name: Checkout uses: actions/checkout@v4 diff --git a/docs/quantize.qmd b/docs/quantize.qmd index 113fcafbe..43c817a5b 100644 --- a/docs/quantize.qmd +++ b/docs/quantize.qmd @@ -51,3 +51,11 @@ axolotl quantize qat.yml ``` This ensures that an identical quantization configuration is used to quantize the model as was used to train it. + + +::: {.callout-note} + +If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it, +e.g. 
`axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w` + +::: diff --git a/examples/llama-3/3b-qat-fsdp2-nvfp4.yaml b/examples/llama-3/3b-qat-fsdp2-nvfp4.yaml new file mode 100644 index 000000000..1ec809bbe --- /dev/null +++ b/examples/llama-3/3b-qat-fsdp2-nvfp4.yaml @@ -0,0 +1,64 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + split: train[:95%] + +output_dir: ./outputs/qat_out/ +dataset_prepared_path: ./outputs/dataset_prepared + +sequence_len: 8192 +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_checkpointing: true +gradient_accumulation_steps: 1 +micro_batch_size: 64 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true +bf16: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 + +special_tokens: + pad_token: <|finetune_right_pad_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml index 35e3461e2..0c5a87891 100644 --- a/examples/llama-3/3b-qat-fsdp2.yaml +++ b/examples/llama-3/3b-qat-fsdp2.yaml @@ -15,20 +15,18 @@ liger_glu_activation: true liger_layer_norm: true liger_fused_linear_cross_entropy: true + datasets: - path: yahma/alpaca-cleaned type: alpaca + split: train[:95%] output_dir: ./outputs/qat_out/ +dataset_prepared_path: ./outputs/qat_out/dataset_prepared -sample_packing: true - -sequence_len: 512 - -flex_attention: true -flex_attn_compile_kwargs: - dynamic: false - mode: max-autotune-no-cudagraphs +sample_packing: false +sequence_len: 8192 +flash_attention: true qat: activation_dtype: int8 @@ -67,7 +65,7 @@ fsdp: fsdp_config: fsdp_version: 2 fsdp_offload_params: false - fsdp_cpu_ram_efficient_loading: true + fsdp_cpu_ram_efficient_loading: false fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT @@ -76,6 +74,6 @@ fsdp_config: fsdp_activation_checkpointing: true special_tokens: - pad_token: <|end_of_text|> + pad_token: <|finetune_right_pad_id|> # save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/requirements.txt b/requirements.txt index 1292a179a..6138707af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -64,7 +64,7 @@ langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 -torchao==0.12.0 +torchao==0.13.0 schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 diff --git a/setup.py b/setup.py index 4cbc562e0..3a44f0ae9 100644 --- a/setup.py +++ b/setup.py @@ -162,6 +162,7 @@ extras_require = { "llmcompressor": [ "llmcompressor==0.5.1", ], + "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"], } install_requires, dependency_links, extras_require_build = parse_requirements( extras_require diff --git 
a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 396e9a8af..14dafa43f 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -115,6 +115,7 @@ class QuantizeCliArgs: quantize_embedding: Optional[bool] = field(default=None) group_size: Optional[int] = field(default=None) output_dir: Optional[str] = field(default=None) + hub_model_id: Optional[str] = field(default=None) @dataclass diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index b8a8de781..6838f47d8 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao from pathlib import Path from typing import Union -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig from axolotl.cli.config import load_cfg from axolotl.loaders import load_tokenizer from axolotl.utils.logging import get_logger -from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq +from axolotl.utils.quantization import ( + TorchAOQuantDType, + get_quantization_config, + quantization_config_to_str, + quantize_model, +) LOG = get_logger(__name__) @@ -43,13 +48,13 @@ def do_quantize( "No quantization configuration found. Please specify either qat or quantization in your config file." ) - model_path = cli_args.get("model_path") or cfg.output_dir + model_path = cli_args.get("base_model") or cfg.output_dir if weight_dtype := cli_args.get("weight_dtype"): - weight_dtype = TorchIntDType[weight_dtype] + weight_dtype = TorchAOQuantDType.from_string(weight_dtype) else: weight_dtype = quantize_cfg.weight_dtype if activation_dtype := cli_args.get("activation_dtype"): - activation_dtype = TorchIntDType[activation_dtype] + activation_dtype = TorchAOQuantDType.from_string(activation_dtype) else: activation_dtype = quantize_cfg.activation_dtype group_size = cli_args.get("group_size") or quantize_cfg.group_size @@ -57,10 +62,15 @@ def do_quantize( cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding ) output_dir = cli_args.get("output_dir") or cfg.output_dir + hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id - LOG.info(f"Loading model from {model_path}...") + LOG.info(f"Loading model from {model_path}.") tokenizer = load_tokenizer(cfg) - model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") + config = AutoConfig.from_pretrained(model_path) + torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None + model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto", torch_dtype=torch_dtype + ) LOG.info( f"Quantizing model with configuration: \n" @@ -70,11 +80,21 @@ def do_quantize( f"\tquantize_embedding: {quantize_embedding}" ) - quantize_model_for_ptq( + quantize_model( model, weight_dtype, group_size, activation_dtype, quantize_embedding ) - LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...") + quantization_config = get_quantization_config( + weight_dtype, activation_dtype, group_size + ) + + ao_config = TorchAoConfig( + quant_type=quantization_config, + include_input_output_embeddings=quantize_embedding, + ) + model.config.quantization_config = ao_config + + LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.") model.save_pretrained( str(Path(output_dir) / "quantized"), safe_serialization=False, @@ -86,4 +106,14 @@ def do_quantize( progressbar=True, save_jinja_files=cfg.tokenizer_save_jinja_files, ) - LOG.info(f"Quantized model saved to: 
{str(Path(output_dir) / 'quantized')}...") + + if hub_model_id: + hub_model_id = ( + hub_model_id.rstrip("-") + + f"-{quantization_config_to_str[type(quantization_config)]}" + ) + model.push_to_hub(hub_model_id, safe_serialization=False) + tokenizer.push_to_hub(hub_model_id) + LOG.info(f"Quantized model pushed to: {hub_model_id}.") + + LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.") diff --git a/src/axolotl/train.py b/src/axolotl/train.py index e8e314579..b0482bb1e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) from axolotl.integrations.base import PluginManager -from axolotl.loaders import ( - ModelLoader, - load_processor, - load_tokenizer, -) +from axolotl.loaders import ModelLoader, load_processor, load_tokenizer from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed @@ -234,16 +230,15 @@ def save_trained_model( # handle QAT if cfg.qat: - from axolotl.utils.quantization import convert_qat_model_for_ptq + from axolotl.utils.quantization import convert_qat_model - LOG.info("Processing QAT model for saving...") - convert_qat_model_for_ptq( + convert_qat_model( model, quantize_embedding=cfg.qat.quantize_embedding, ) LOG.info( - "QAT modules have been converted for PTQ. Please ensure you quantize " - "your model weights with `axolotl quantize`." + "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`" + " with the same config which you used for training." ) # Handle ReLoRA early return case if cfg.relora: @@ -337,9 +332,7 @@ def save_trained_model( if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: # TODO: add integration support so this can be implemented completely within the plugin - from axolotl.integrations.llm_compressor.utils import ( - save_compressed_model, - ) + from axolotl.integrations.llm_compressor.utils import save_compressed_model save_compressed_model( model=model, diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index f9a30b660..6c29a5442 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao. 
""" import torch -from torch import nn +from packaging import version from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, - IntXQuantizationAwareTrainingConfig, + QATConfig, ) from torchao.quantization.quant_api import ( - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt4WeightConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, - _is_linear, ) -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType + +quantization_config_to_str = { + Int8DynamicActivationInt4WeightConfig: "int8int4", + Float8DynamicActivationFloat8WeightConfig: "fp8fp8", + Float8DynamicActivationInt4WeightConfig: "fp8int4", +} + +if version.parse(torch.__version__) >= version.parse("2.8.0"): + try: + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" + except: + pass + + # int4 weight config imports will fail on machines with fbgemm-gpu installed + # without a CUDA runtime available so we do this safely + try: + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + quantization_config_to_str[Int4WeightOnlyConfig] = "int4" + except: + pass -def get_ptq_config( - weight_dtype: TorchIntDType, - activation_dtype: TorchIntDType | None = None, +def get_quantization_config( + weight_dtype: TorchAOQuantDType, + activation_dtype: TorchAOQuantDType | None = None, group_size: int | None = None, ) -> AOBaseConfig: """ @@ -45,44 +62,101 @@ def get_ptq_config( or if the group size is not specified for int8 or int4 weight only quantization. """ if activation_dtype is None: - if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr] - return UIntXWeightOnlyConfig( - dtype=weight_dtype.value, - group_size=group_size, - set_inductor_config=False, - ) - if weight_dtype == TorchIntDType.int8: - if group_size is None: - raise ValueError( - "group_size must be specified for int8 weight only quantization" - ) - return Int8WeightOnlyConfig( - group_size=group_size, - ) - if weight_dtype == TorchIntDType.int4: - if group_size is None: - raise ValueError( - "group_size must be specified for int4 weight only quantization" - ) - return Int4WeightOnlyConfig( - group_size=group_size, - ) - if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4: - return Int4DynamicActivationInt4WeightConfig() - if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8: - return Int8DynamicActivationInt8WeightConfig() - if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4: - return Int8DynamicActivationInt4WeightConfig() + if weight_dtype == TorchAOQuantDType.int8: + raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.") + if weight_dtype == TorchAOQuantDType.int4: + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + if group_size is not None: + return Int4WeightOnlyConfig(group_size=group_size, version=2) + else: + return Int4WeightOnlyConfig(version=2) + if ( + activation_dtype == TorchAOQuantDType.int4 + and weight_dtype == TorchAOQuantDType.int4 + ): + raise ValueError( + "Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT." 
+ ) + if ( + activation_dtype == TorchAOQuantDType.int8 + and weight_dtype == TorchAOQuantDType.int8 + ): + raise ValueError( + "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT." + ) + if ( + activation_dtype == TorchAOQuantDType.int8 + and weight_dtype == TorchAOQuantDType.int4 + ): + if group_size is not None: + return Int8DynamicActivationInt4WeightConfig(group_size=group_size) + else: + return Int8DynamicActivationInt4WeightConfig() + if ( + activation_dtype == TorchAOQuantDType.float8_e4m3fn + and weight_dtype == TorchAOQuantDType.float8_e4m3fn + ): + return Float8DynamicActivationFloat8WeightConfig() + if ( + activation_dtype == TorchAOQuantDType.float8_e4m3fn + and weight_dtype == TorchAOQuantDType.int4 + ): + return Float8DynamicActivationInt4WeightConfig() + if weight_dtype == TorchAOQuantDType.nvfp4: + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + if group_size is not None and group_size != 16: + raise ValueError("NVFP4 quantization must use a group_size of 16") + return NVFP4InferenceConfig() raise ValueError( f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" ) +def quantize_model( + model, + weight_dtype: TorchAOQuantDType, + group_size: int | None = None, + activation_dtype: TorchAOQuantDType | None = None, + quantize_embedding: bool | None = None, +): + """ + This function is used to quantize a model. + + Args: + model: The model to quantize. + weight_dtype: The dtype to use for weight quantization. + group_size: The group size to use for weight quantization. + activation_dtype: The dtype to use for activation quantization. + quantize_embedding: Whether to quantize the model's embedding weights. + + """ + linear_ptq_config = get_quantization_config( + weight_dtype=weight_dtype, + activation_dtype=activation_dtype, + group_size=group_size, + ) + quantize_(model, linear_ptq_config) + if quantize_embedding: + # activation fake quantization is not supported for embedding layers + embedding_quantize_config = get_quantization_config( + weight_dtype=weight_dtype, + activation_dtype=None, + group_size=group_size, + ) + quantize_( + model, + embedding_quantize_config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) + + def prepare_model_for_qat( model, - weight_dtype: TorchIntDType, - group_size: int, - activation_dtype: TorchIntDType | None = None, + weight_dtype: TorchAOQuantDType, + group_size: int | None = None, + activation_dtype: TorchAOQuantDType | None = None, quantize_embedding: bool = False, ): """ @@ -100,86 +174,40 @@ def prepare_model_for_qat( Raises: ValueError: If the activation/weight dtype combination is invalid. 
""" - if activation_dtype: - activation_config = FakeQuantizeConfig( - dtype=activation_dtype.value, granularity="per_token", is_symmetric=False - ) - weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size) - linear_quantize_config = IntXQuantizationAwareTrainingConfig( - activation_config=None if activation_dtype is None else activation_config, - weight_config=weight_config, - ) - quantize_(model, linear_quantize_config) - if quantize_embedding: - # activation fake quantization is not supported for embedding layers - embedding_quantize_config = IntXQuantizationAwareTrainingConfig( - activation_config=None, - weight_config=weight_config, - ) - quantize_( - model, - embedding_quantize_config, - filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), - ) - - -def quantize_model_for_ptq( - model, - weight_dtype: TorchIntDType, - group_size: int | None = None, - activation_dtype: TorchIntDType | None = None, - quantize_embedding: bool | None = None, -): - """ - This function is used to quantize a model for post-training quantization. - It swaps the model's linear layers with fake quantized linear layers. - If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights. - - Args: - model: The model to quantize. - weight_dtype: The dtype to use for weight quantization. - group_size: The group size to use for weight quantization. - activation_dtype: The dtype to use for activation quantization. - quantize_embedding: Whether to quantize the model's embedding weights. - - """ - linear_ptq_config = get_ptq_config( + base_config = get_quantization_config( weight_dtype=weight_dtype, activation_dtype=activation_dtype, group_size=group_size, ) - quantize_(model, linear_ptq_config) + qat_config = QATConfig(base_config) + quantize_(model, qat_config) if quantize_embedding: - embedding_quantize_config = get_ptq_config( + # activation fake quantization is not supported for embedding layers + embedding_base_config = get_quantization_config( weight_dtype=weight_dtype, activation_dtype=None, group_size=group_size, ) + embedding_qat_config = QATConfig(embedding_base_config) quantize_( model, - embedding_quantize_config, + embedding_qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), ) -def convert_qat_model_for_ptq( +def convert_qat_model( model, - *, - quantize_embedding: bool | None = None, + quantize_embedding: bool = False, ): """ - This function is used to convert a swap fake-quantized modules in a model - which has been trained with QAT back to the original modules, ready for PTQ. - - Args: - model: The model to convert. - quantize_embedding: Whether to quantize the model's embedding weights. + This function converts a QAT model which has fake quantized layers back to the original model. 
""" + config = QATConfig(step="convert") + quantize_(model, config) if quantize_embedding: - - def filter_fn(m, _): - return isinstance(m, nn.Embedding) or _is_linear(m) - - else: - filter_fn = _is_linear - quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn) + quantize_( + model, + config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 8f4718aa9..bcd03e1a2 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -5,18 +5,21 @@ from enum import Enum import torch -class TorchIntDType(Enum): - """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4""" +class TorchAOQuantDType(Enum): + int4 = torch.int4 + int8 = torch.int8 + float8_e4m3fn = torch.float8_e4m3fn + nvfp4 = "nvfp4" - uint1 = getattr(torch, "uint1", None) - uint2 = getattr(torch, "uint2", None) - uint3 = getattr(torch, "uint3", None) - uint4 = getattr(torch, "uint4", None) - uint5 = getattr(torch, "uint5", None) - uint6 = getattr(torch, "uint6", None) - uint7 = getattr(torch, "uint7", None) - int4 = getattr(torch, "int4", None) - int8 = getattr(torch, "int8", None) + def from_string(str): + if str == "int4": + return TorchAOQuantDType.int4 + if str == "int8": + return TorchAOQuantDType.int8 + if str in ["float8_e4m3fn", "fp8", "float8"]: + return TorchAOQuantDType.float8_e4m3fn + if str == "nvfp4": + return TorchAOQuantDType.nvfp4 class RLType(str, Enum): diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py index 090640c7b..a7c130574 100644 --- a/src/axolotl/utils/schemas/quantization.py +++ b/src/axolotl/utils/schemas/quantization.py @@ -6,7 +6,23 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType + + +def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None: + if v is None: + return None + if v == "int4": + return TorchAOQuantDType.int4 + if v == "int8": + return TorchAOQuantDType.int8 + if v in ["float8_e4m3fn", "fp8", "float8"]: + return TorchAOQuantDType.float8_e4m3fn + if v == "nvfp4": + return TorchAOQuantDType.nvfp4 + raise ValueError( + f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}" + ) class QATConfig(BaseModel): @@ -14,13 +30,13 @@ class QATConfig(BaseModel): QAT Config Schema """ - activation_dtype: TorchIntDType | None = Field( + activation_dtype: TorchAOQuantDType | None = Field( default=None, - description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', + description="Fake quantization layout to use for activation quantization.", ) - weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, - description='Fake quantization layout to use for weight quantization. 
Valid options are "int4" and "int8"', + weight_dtype: TorchAOQuantDType = Field( + default=TorchAOQuantDType.int8, + description="Fake quantization layout to use for weight quantization.", ) quantize_embedding: bool | None = Field( default=False, description="Quantize embedding" @@ -35,12 +51,8 @@ class QATConfig(BaseModel): @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod - def validate_dtype(cls, v: Any) -> TorchIntDType | None: - if v == "int4": - return TorchIntDType.int4 - if v == "int8": - return TorchIntDType.int8 - raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']") + def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: + return validate_ao_dtype(v) class PTQConfig(BaseModel): @@ -48,13 +60,13 @@ class PTQConfig(BaseModel): PTQ Config Schema """ - weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, - description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8", + weight_dtype: TorchAOQuantDType = Field( + default=TorchAOQuantDType.int8, + description="Fake quantization layout to use for weight quantization.", ) - activation_dtype: TorchIntDType | None = Field( + activation_dtype: TorchAOQuantDType | None = Field( default=None, - description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', + description="Fake quantization layout to use for activation quantization.", ) quantize_embedding: bool | None = Field( default=None, description="Whether to quantize the embedding layer." @@ -66,9 +78,5 @@ class PTQConfig(BaseModel): @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod - def validate_dtype(cls, v: Any) -> TorchIntDType | None: - if v == "int4": - return TorchIntDType.int4 - if v == "int8": - return TorchIntDType.int8 - raise ValueError(f"Invalid dtype: '{v}'. 
Must be one of: ['int4', 'int8']") + def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: + return validate_ao_dtype(v) diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py index 7d41dfb50..2f8398ef7 100644 --- a/tests/e2e/test_qat.py +++ b/tests/e2e/test_qat.py @@ -43,7 +43,7 @@ class TestQATLlama: "qat": { "quantize_embedding": True, "activation_dtype": "int8", - "weight_dtype": "int8", + "weight_dtype": "int4", "group_size": 8, }, "num_epochs": 1, @@ -111,7 +111,7 @@ class TestQATLlama: "qat": { "quantize_embedding": True, "activation_dtype": "int8", - "weight_dtype": "int8", + "weight_dtype": "int4", "group_size": 8, }, "save_first_step": False, diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py index cfbdfec38..b64aef51a 100644 --- a/tests/e2e/test_quantization.py +++ b/tests/e2e/test_quantization.py @@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization import pytest import torch from torch import nn -from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor -from torchao.quantization.granularity import PerAxis, PerGroup -from torchao.quantization.linear_activation_quantized_tensor import ( - LinearActivationQuantizedTensor, -) +from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.quant_api import ( - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt4WeightConfig, ) +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor from transformers import AutoModelForCausalLM from transformers.trainer_callback import TrainerState from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.quantization import ( - convert_qat_model_for_ptq, - get_ptq_config, + convert_qat_model, + get_quantization_config, prepare_model_for_qat, - quantize_model_for_ptq, + quantize_model, ) -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType from axolotl.utils.schemas.quantization import QATConfig -from tests.e2e.utils import require_torch_2_6_0 +from tests.e2e.utils import ( + require_torch_2_8_0, + requires_cuda_ge_8_9, + requires_sm_ge_100, +) @pytest.fixture() def model(): dummy_model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceTB/SmolLM2-135M", - device_map="cuda", + "Qwen/Qwen2-0.5B", + device_map="auto", torch_dtype=torch.bfloat16, ) with torch.device(dummy_model.device): @@ -48,45 +47,56 @@ def model(): dummy_model.model.embed_tokens.weight.shape[1], dtype=dummy_model.model.embed_tokens.weight.dtype, ) - return dummy_model + yield dummy_model + del dummy_model ptq_config_test_cases = [ - # weight_dtype, activation_dtype, group_size, expected_type, expected_params + # weight_dtype, activation_dtype, group_size, expected_type ( - TorchIntDType.uint4, + TorchAOQuantDType.int4, + TorchAOQuantDType.int8, None, - None, - UIntXWeightOnlyConfig, - {"dtype": torch.uint4, "group_size": None}, - ), - (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}), - (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}), - ( - TorchIntDType.int4, - TorchIntDType.int4, - None, - Int4DynamicActivationInt4WeightConfig, - {}, + 
Int8DynamicActivationInt4WeightConfig, ), ( - TorchIntDType.int8, - TorchIntDType.int8, + TorchAOQuantDType.float8_e4m3fn, + TorchAOQuantDType.float8_e4m3fn, None, - Int8DynamicActivationInt8WeightConfig, - {}, + Float8DynamicActivationFloat8WeightConfig, + ), + ( + TorchAOQuantDType.int4, + TorchAOQuantDType.float8_e4m3fn, + None, + Float8DynamicActivationInt4WeightConfig, ), ] ptq_test_cases = [ - # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception - (TorchIntDType.int8, None, 8, False, None), - (TorchIntDType.int4, None, 4, True, None), - (TorchIntDType.uint4, None, 8, False, None), - (TorchIntDType.int4, TorchIntDType.int4, 8, False, None), - (TorchIntDType.int8, TorchIntDType.int8, 8, True, None), - (TorchIntDType.int8, None, None, False, ValueError), - (TorchIntDType.int4, None, None, False, ValueError), + # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class + (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor), + ( + TorchAOQuantDType.int4, + TorchAOQuantDType.int8, + 8, + False, + None, + LinearActivationQuantizedTensor, + ), + # ( + # TorchAOQuantDType.int4, + # TorchAOQuantDType.float8_e4m3fn, + # None, + # False, + # None, + # Int4Tensor, + # ), + (TorchAOQuantDType.int4, None, None, False, None, Int4Tensor), + # Deprecated configs + (TorchAOQuantDType.int8, None, 8, False, ValueError, None), + (TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None), + (TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None), ] @@ -96,44 +106,132 @@ class TestQuantization: """ @pytest.mark.parametrize( - "weight_dtype,activation_dtype,group_size,expected_type,expected_params", + "weight_dtype,activation_dtype,group_size,expected_type", ptq_config_test_cases, ) - @require_torch_2_6_0 + @requires_cuda_ge_8_9 + @require_torch_2_8_0 def test_get_ptq_config( - self, weight_dtype, activation_dtype, group_size, expected_type, expected_params + self, weight_dtype, activation_dtype, group_size, expected_type ): - config = get_ptq_config(weight_dtype, activation_dtype, group_size) - + config = get_quantization_config(weight_dtype, activation_dtype, group_size) assert isinstance(config, expected_type) - for param_name, param_value in expected_params.items(): - if isinstance(param_value, (PerAxis, PerGroup)): - if isinstance(param_value, PerAxis): - assert isinstance(getattr(config, param_name), PerAxis) - assert getattr(config, param_name).axis == param_value.axis - else: - assert isinstance(getattr(config, param_name), PerGroup) - assert ( - getattr(config, param_name).group_size == param_value.group_size - ) - else: - assert getattr(config, param_name) == param_value + @requires_cuda_ge_8_9 + @require_torch_2_8_0 + def test_get_ptq_config_int4_weight_only(self): + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + config = get_quantization_config(TorchAOQuantDType.int4, None, 4) + assert isinstance(config, Int4WeightOnlyConfig) @pytest.mark.parametrize( - "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4] + "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class", + ptq_test_cases, ) + @requires_cuda_ge_8_9 + @require_torch_2_8_0 + def test_quantize_model_for_ptq( + self, + model, + weight_dtype, + activation_dtype, + group_size, + quantize_embedding, + expected_exception, + expected_tensor_class, + ): + if expected_exception: + with pytest.raises(expected_exception): + 
quantize_model( + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, + ) + else: + quantize_model( + model, weight_dtype, group_size, activation_dtype, quantize_embedding + ) + if quantize_embedding: + assert isinstance( + model.model.embed_tokens.weight, expected_tensor_class + ), "Embedding weight should be quantized" + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, expected_tensor_class) + + @require_torch_2_8_0 + @requires_sm_ge_100 + def test_quantize_model_for_ptq_fp8( + self, + model, + ): + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, + ) + + quantize_model( + model, + TorchAOQuantDType.float8_e4m3fn, + None, + TorchAOQuantDType.float8_e4m3fn, + ) + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, Float8Tensor) + assert child.weight.act_quant_kwargs is not None and isinstance( + child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs + ) + + @require_torch_2_8_0 + @requires_sm_ge_100 + def test_quantize_model_for_ptq_nvfp4( + self, + model, + ): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + QuantizeTensorToNVFP4Kwargs, + ) + + quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4) + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, NVFP4Tensor) + assert child.weight.act_quant_kwargs is not None and isinstance( + child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs + ) + @pytest.mark.parametrize( - "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8] + "weight_dtype,activation_dtype,group_size,quantize_embedding", + [ + (TorchAOQuantDType.int4, None, 8, False), + (TorchAOQuantDType.int4, None, 16, True), + (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False), + (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True), + ( + TorchAOQuantDType.float8_e4m3fn, + TorchAOQuantDType.float8_e4m3fn, + None, + False, + ), + (TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True), + ], ) - @pytest.mark.parametrize("group_size", [4, 8]) - @pytest.mark.parametrize("quantize_embedding", [False, True]) - @require_torch_2_6_0 + @require_torch_2_8_0 + @requires_cuda_ge_8_9 def test_prepare_model_for_qat( self, model, weight_dtype, activation_dtype, group_size, quantize_embedding ): prepare_model_for_qat( - model, weight_dtype, group_size, activation_dtype, quantize_embedding + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, ) if quantize_embedding: assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) @@ -142,17 +240,19 @@ class TestQuantization: model.model.embed_tokens.weight_fake_quantizer.config.dtype == weight_dtype.value ) - assert ( - model.model.embed_tokens.weight_fake_quantizer.config.group_size - == group_size - ) + if group_size: + assert ( + model.model.embed_tokens.weight_fake_quantizer.config.group_size + == group_size + ) for child in list(model.children()): if isinstance(child, torch.nn.Linear): assert isinstance(child, FakeQuantizedLinear) assert hasattr(child, "weight_fake_quantizer") assert child.weight_fake_quantizer.config.dtype == weight_dtype.value - assert child.weight_fake_quantizer.config.group_size == group_size + if group_size: + assert child.weight_fake_quantizer.config.group_size == group_size if activation_dtype: assert hasattr(child, 
"activation_fake_quantizer") assert ( @@ -162,49 +262,40 @@ class TestQuantization: else: assert child.activation_fake_quantizer is None - @pytest.mark.parametrize( - "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception", - ptq_test_cases, - ) - @require_torch_2_6_0 - def test_quantize_model_for_ptq( - self, - model, - weight_dtype, - activation_dtype, - group_size, - quantize_embedding, - expected_exception, - ): - if expected_exception: - with pytest.raises(expected_exception): - quantize_model_for_ptq( - model, - weight_dtype, - group_size, - activation_dtype, - quantize_embedding, - ) - else: - quantize_model_for_ptq( - model, weight_dtype, group_size, activation_dtype, quantize_embedding - ) - if quantize_embedding: - assert isinstance( - model.model.embed_tokens.weight, AffineQuantizedTensor - ), "Embedding weight should be quantized" - for child in list(model.children()): - if isinstance(child, torch.nn.Linear): - if activation_dtype: - assert isinstance( - child.weight, LinearActivationQuantizedTensor - ), ( - "Linear weight should be quantized with activation quantization" - ) - else: - assert isinstance(child.weight, AffineQuantizedTensor), ( - "Linear weight should be quantized without activation quantization" - ) + @require_torch_2_8_0 + @requires_cuda_ge_8_9 + def test_convert_qat_model(self, model): + config = QATConfig( + weight_dtype="int4", + activation_dtype="int8", + group_size=8, + quantize_embedding=True, + ) + + # quantize model for qat + prepare_model_for_qat( + model, + config.weight_dtype, + config.group_size, + config.activation_dtype, + config.quantize_embedding, + ) + + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert isinstance(model.lm_head, FakeQuantizedLinear) + + # apply conversion + convert_qat_model( + model, + config.quantize_embedding, + ) + # ensure modules have been swapped out + assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert not isinstance(model.lm_head, FakeQuantizedLinear) + + # ensure weights have been quantized + assert isinstance(model.model.embed_tokens.weight, nn.Parameter) + assert isinstance(model.lm_head.weight, nn.Parameter) class TestQuantizationCallback: @@ -218,10 +309,10 @@ class TestQuantizationCallback: global_step=0, ) - @require_torch_2_6_0 + @require_torch_2_8_0 def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state): cfg = QATConfig( - weight_dtype="int8", + weight_dtype="int4", activation_dtype="int8", group_size=8, quantize_embedding=True, @@ -268,10 +359,10 @@ class TestQuantizationCallback: assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled - @require_torch_2_6_0 + @require_torch_2_8_0 def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state): cfg = QATConfig( - weight_dtype="int8", + weight_dtype="int4", activation_dtype="int8", group_size=8, quantize_embedding=True, @@ -304,43 +395,3 @@ class TestQuantizationCallback: # quantization should be enabled from the get-go assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled - - -class TestConvertQATModelForPTQ: - """ - Test convert_qat_model_for_ptq - """ - - @require_torch_2_6_0 - def test_convert_qat_model_for_ptq(self, model): - config = QATConfig( - weight_dtype="int8", - activation_dtype="int8", - group_size=8, - quantize_embedding=True, - ) - - # quantize model for qat - prepare_model_for_qat( - model, - 
config.weight_dtype,
-            config.group_size,
-            config.activation_dtype,
-            config.quantize_embedding,
-        )
-
-        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
-        assert isinstance(model.lm_head, FakeQuantizedLinear)
-
-        # apply conversion
-        convert_qat_model_for_ptq(
-            model,
-            quantize_embedding=config.quantize_embedding,
-        )
-        # ensure modules have been swapped out
-        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
-        assert not isinstance(model.lm_head, FakeQuantizedLinear)
-
-        # ensure weights have been quantized
-        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
-        assert isinstance(model.lm_head.weight, nn.Parameter)
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 7db6cf74e..a2dd8bc5e 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -90,6 +90,18 @@ def require_torch_2_7_0(test_case):
     return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
 
 
+def require_torch_2_8_0(test_case):
+    """
+    Decorator marking a test that requires torch >= 2.8.0
+    """
+
+    def is_min_2_8_0():
+        torch_version = version.parse(torch.__version__)
+        return torch_version >= version.parse("2.8.0")
+
+    return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case)
+
+
 def require_torch_lt_2_6_0(test_case):
     """
     Decorator marking a test that requires torch < 2.6.0
@@ -128,6 +140,30 @@ def require_llmcompressor(test_case):
     )(test_case)
 
 
+def requires_sm_ge_100(test_case):
+    """
+    Decorator marking a test that requires CUDA compute capability >= 10.0
+    """
+    is_sm_ge_100 = (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (10, 0)
+    )
+    return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)
+
+
+def requires_cuda_ge_8_9(test_case):
+    """
+    Decorator marking a test that requires CUDA compute capability >= 8.9
+    """
+    is_cuda_ge_8_9 = (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (8, 9)
+    )
+    return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)
+
+
 def is_hopper():
     compute_capability = torch.cuda.get_device_capability()
     return compute_capability == (9, 0)
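
For reference, a minimal sketch of the post-training quantization flow this series lands (the model name mirrors the test fixture above; the dtype pair and group size are illustrative, and torch >= 2.8.0 plus a CUDA GPU with compute capability >= 8.9 are assumed):

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.utils.quantization import quantize_model
from axolotl.utils.schemas.enums import TorchAOQuantDType

# Load the fine-tuned (or QAT-converted) model in its trained dtype.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B", device_map="auto", torch_dtype=torch.bfloat16
)

# int8 dynamic activations with int4 weights ("int8int4" in the hub-suffix map).
weight_dtype = TorchAOQuantDType.from_string("int4")
activation_dtype = TorchAOQuantDType.from_string("int8")

# Swaps the model's linear weights for torchao quantized tensor subclasses in place.
quantize_model(
    model,
    weight_dtype,
    group_size=8,
    activation_dtype=activation_dtype,
)
```

From the CLI, the same flow is driven by `axolotl quantize` with the training config, as described in docs/quantize.qmd above.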