From 9640338d37d0398cd3c0c0ab6e629b6dd9dcd5d3 Mon Sep 17 00:00:00 2001
From: salman
+
+ def _style_for(i: int, tid: int) -> str:
+ if i in masked_positions:
+ if i < len(orig_ids) and tid == orig_ids[i]:
+ return "green"
+ if i < len(orig_ids):
+ return "red"
+ return "normal"
+ same = i < len(orig_ids) and tid == orig_ids[i]
+ return "dim" if same else "normal"
+
+ # Group contiguous spans by style to reduce HTML size
+ spans: list[tuple[str, int, int]] = []
+ if generated_ids:
+ cur = _style_for(0, generated_ids[0])
+ start = 0
+ for i in range(1, len(generated_ids)):
+ s = _style_for(i, generated_ids[i])
+ if s != cur:
+ spans.append((cur, start, i))
+ cur, start = s, i
+ spans.append((cur, start, len(generated_ids)))
+
+ html_parts = []
+ for style_name, a, b in spans:
+ txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
+ if style_name == "green":
+ html_parts.append(f'<span style="color: green;">{txt}</span>')
+ elif style_name == "red":
+ html_parts.append(f'<span style="color: red;">{txt}</span>')
+ elif style_name == "dim":
+ html_parts.append(f'<span style="opacity: 0.6;">{txt}</span>')
+ else:
+ html_parts.append(txt)
+
+ legend = (
+ '<pre style="white-space: pre-wrap;">Generated:\n'
+ + "".join(html_parts)
+ + "</pre>"
+ )
+ return legend
+
+
+def launch_diffusion_gradio_ui(
+ *,
+ model,
+ tokenizer,
+ cfg: DictDefault,
+ prompter_module=None,
+ chat_template_str: str | None = None,
+):
+ """Build and launch a simple Gradio UI for diffusion inference."""
+ with gr.Blocks(
+ title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
+ ) as demo:
+ gr.Markdown(
+ """
+ ## Axolotl Diffusion Inference
+ - Mode "Random" masks tokens at a target ratio and fills them.
+ - Mode "Completion" appends N masked tokens at the end and fills them.
+ """
+ )
+
+ with gr.Row():
+ mode = gr.Radio(
+ choices=["random", "completion"],
+ value="random",
+ label="Mode",
+ )
+ mask_ratio = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ step=0.05,
+ value=0.4,
+ label="Mask ratio (random mode)",
+ interactive=True,
+ )
+ completion_tokens = gr.Number(
+ value=64,
+ precision=0,
+ label="Completion tokens (completion mode)",
+ interactive=True,
+ visible=False,
+ )
+
+ instruction = gr.Textbox(label="Instruction", lines=6)
+ run_btn = gr.Button("Generate")
+
+ masked_preview = gr.Textbox(label="Masked preview", lines=6)
+ html_out = gr.HTML(label="Generated")
+
+ def _toggle_controls(selected_mode: str):
+ return (
+ gr.update(visible=(selected_mode == "random")),
+ gr.update(visible=(selected_mode == "completion")),
+ )
+
+ mode.change(
+ _toggle_controls,
+ inputs=[mode],
+ outputs=[mask_ratio, completion_tokens],
+ )
+
+ def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
+ if not instruction_text:
+ return "", "Generated:\n(no output)
"
+
+ if prompter_module:
+ prompt: str = next(
+ prompter_module().build_prompt(
+ instruction=instruction_text.strip("\n")
+ )
+ )
+ else:
+ prompt = instruction_text.strip()
+
+ info = run_diffusion(
+ model=model,
+ tokenizer=tokenizer,
+ cfg=cfg,
+ prompt=prompt,
+ chat_template_str=chat_template_str,
+ mode=selected_mode,
+ target_mask_ratio=mratio if selected_mode == "random" else None,
+ completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
+ )
+
+ masked_text = info.get("masked_text")
+ mask_ratio_val = info.get("mask_ratio")
+ generated_ids = info.get("generated_ids")
+ masked_positions = info.get("masked_positions") or set()
+ orig_ids = info.get("orig_ids") or []
+
+ preview = (
+ f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
+ if masked_text is not None and mask_ratio_val is not None
+ else ""
+ )
+ html = render_html(
+ generated_ids=generated_ids,
+ orig_ids=orig_ids,
+ masked_positions=masked_positions,
+ tokenizer=tokenizer,
+ )
+ return preview, html
+
+ run_btn.click(
+ _gen,
+ inputs=[instruction, mode, mask_ratio, completion_tokens],
+ outputs=[masked_preview, html_out],
+ )
+
+ demo.queue().launch(
+ show_api=False,
+ share=cfg.get("gradio_share", True),
+ server_name=cfg.get("gradio_server_name", "127.0.0.1"),
+ server_port=cfg.get("gradio_server_port", None),
+ )
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index ee6383d47..f7f350e1a 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -7,7 +7,11 @@ from pathlib import Path
from typing import Type, Union
import transformers
-from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
+from transformers import (
+ DataCollatorWithFlattening,
+ EarlyStoppingCallback,
+ Trainer,
+)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
@@ -23,15 +27,16 @@ from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
+ LossWatchDogCallback,
+ SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
- LossWatchDogCallback,
- SaveBetterTransformerModelCallback,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
+from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -39,7 +44,6 @@ from axolotl.utils.collators import (
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
-from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
@@ -391,10 +395,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
- if "processing_class" in sig.parameters:
+ if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
+
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index d7555261f..3427a0b86 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -49,6 +49,13 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
+REDUCTION_FNS = {
+ "mean": torch.mean,
+ "min": torch.min,
+ "max": torch.max,
+ "sum": torch.sum,
+}
+
class AxolotlTrainer(
PackingMixin,
@@ -89,7 +96,9 @@ class AxolotlTrainer(
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
+ self._stored_metrics = defaultdict(
+ lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
+ )
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -585,9 +594,17 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
- # Add averaged stored metrics to logs
- for key, metrics in self._stored_metrics[train_eval].items():
- logs[key] = torch.tensor(metrics).mean().item()
+
+ for key, metric_data in self._stored_metrics[train_eval].items():
+ values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
+ reduction_type = metric_data["reduction"]
+
+ fn = REDUCTION_FNS.get(reduction_type)
+ if fn is None:
+ raise NotImplementedError(
+ "Metric reduction must be one of [mean, min, max, sum]"
+ )
+ logs[key] = round(fn(values).item(), 4)
if is_main_process():
# Add memory usage
@@ -611,10 +628,27 @@ class AxolotlTrainer(
return super().log(logs, start_time)
def store_metrics(
- self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
+ self,
+ metrics: dict[str, float] | dict[str, tuple[int | float, str]],
+ train_eval: Literal["train", "eval"] = "train",
+ reduction: Literal["mean", "min", "max", "sum"] = "mean",
) -> None:
+ """
+ Store metrics with specified reduction type.
+
+ Args:
+ metrics: Dictionary of metric names to values, or metric names to (value,
+ reduction_type) tuples.
+ train_eval: Whether this is for training or evaluation.
+ reduction: Default reduction applied when a metric value is passed as a
+ bare value rather than a (value, reduction_type) tuple.
+ """
for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
+ if isinstance(value, tuple):
+ value, _reduction = value # type: ignore[assignment]
+ else:
+ value, _reduction = value, reduction
+
+ self._stored_metrics[train_eval][key]["values"].append(value)
+ self._stored_metrics[train_eval][key]["reduction"] = _reduction
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 8edee18a3..c66bc01c6 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -142,7 +142,7 @@ class BasePlugin:
model: The loaded model.
"""
- def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
+ def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Returns a custom class for the trainer.
Args:
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
index 2217b2819..8ae8aab39 100644
--- a/src/axolotl/integrations/config.py
+++ b/src/axolotl/integrations/config.py
@@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+ AxolotlInputConfig as AxolotlInputConfigBase,
)
-from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
def merge_input_args():
diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md
new file mode 100644
index 000000000..c27f33de1
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/README.md
@@ -0,0 +1,154 @@
+# Diffusion LM Training Plugin for Axolotl
+
+This plugin enables diffusion language model training using an approach inspired by
+LLaDA (Large Language Diffusion Models) within Axolotl.
+
+## Overview
+
+LLaDA is a diffusion-based approach to language model training that uses:
+- **Random token masking** during training instead of next-token prediction
+- **Bidirectional attention** to allow the model to attend to the full context
+- **Importance weighting** based on masking probabilities for stable training
+
+This approach can lead to more robust language models with better understanding of
+bidirectional context.
+
+## Installation
+
+The plugin is included with Axolotl. See our
+[installation docs](https://docs.axolotl.ai/docs/installation.html).
+
+## Quickstart
+
+Train with an example config (Llama-3.2 1B):
+ - Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`
+ - SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`
+
+### Basic Configuration
+
+You can also modify your existing configs to enable / customize diffusion training.
+
+Add the following to your Axolotl config:
+
+```yaml
+# Enable diffusion LM training plugin
+plugins:
+ - axolotl.integrations.diffusion.DiffusionPlugin
+```
+
+And, configure the nested `diffusion` block (defaults shown):
+
+```yaml
+diffusion:
+ noise_schedule: linear # or "cosine"
+ min_mask_ratio: 0.1
+ max_mask_ratio: 0.9
+ num_diffusion_steps: 128
+ eps: 1e-3
+ importance_weighting: true
+
+ # Mask token (training auto-adds if missing, avoid pad/eos)
+ mask_token_str: "<|diffusion_mask|>"
+ # Or use an existing special token id (e.g., 128002 for Llama-3.x)
+ # mask_token_id: 128002
+
+ # Sample generation during training (optional)
+ generate_samples: true
+ generation_interval: 100
+ num_generation_samples: 3
+ generation_steps: 128
+ generation_temperature: 0.0
+ generation_max_length: 100
+```
+
+## Supported Models
+
+Any models that support 4D attention masks should work out of the box. If not, please
+create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a
+[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!
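+
+For reference, the plugin overrides causal masking with a full bidirectional 4D
+boolean mask of shape `[batch, 1, seq_len, seq_len]`; a minimal sketch mirroring
+the plugin's `create_bidirectional_attention_mask` for the unpacked case:
+
+```python
+import torch
+
+def bidirectional_mask(input_ids: torch.Tensor) -> torch.Tensor:
+    # Every position may attend to every other position (no causal triangle)
+    batch, seq_len = input_ids.shape
+    return torch.ones(
+        batch, 1, seq_len, seq_len, dtype=torch.bool, device=input_ids.device
+    )
+```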
+
+## How It Works
+
+### Random Masking
+During training, tokens are randomly masked:
+- Sample timestep `t` uniformly from [0, 1]
+- Calculate masking probability: `p = (1 - eps) * t + eps`
+- Randomly mask tokens with probability `p`
+
+### Diffusion Loss
+
+Loss is computed only on masked tokens with (optional) importance weighting:
+
+```python
+loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
+```
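+
+In the trainer this is computed per masked position; a minimal sketch of the
+weighting (assuming `token_loss` is the unreduced cross-entropy and `p_mask`
+the per-token masking probability, both gathered at masked positions):
+
+```python
+weighted = token_loss / p_mask  # importance weighting by 1 / p_mask
+loss = weighted.sum() / (batch_size * seq_len)  # normalize over all positions
+```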
+
+## Sample Generation
+
+When `diffusion.generate_samples: true`, the plugin generates samples during training:
+
+```
+Sample 1:
+ Original (45 tokens): The quick brown fox jumps over the lazy dog...
+ Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
+ Generated: The quick brown fox jumps over the lazy dog...
+```
+
+Samples are logged to console and wandb (if enabled).
+
+## Inference
+
+Diffusion inference is integrated into the standard Axolotl CLI. Use the same config
+you trained with and run:
+
+```
+axolotl inference path/to/your-config.yaml
+```
+
+Optionally, pass `--gradio` to use a simple web interface.
+
+Interactive controls (prefix the prompt with commands):
+- `:complete N` → completion mode with N new masked tokens appended (default 64)
+- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]
+
+Example session:
+
+```
+================================================================================
+Commands:
+:complete N -> completion mode with N tokens (default 64)
+:mask R -> random masking with ratio R (0.0–1.0)
+================================================================================
+Give me an instruction (Ctrl + D to submit):
+
+:mask 0.4 The quick brown fox jumps over the lazy dog
+
+Masked (40.0%):
+The [MASK] brown [MASK] jumps over the [MASK] dog
+
+Generated:
+The quick brown fox jumps over the loud dog
+```
+
+## Metrics and Monitoring
+
+The plugin adds (or modifies) several metrics to track diffusion training:
+
+- `train/loss`: Weighted diffusion loss
+- `train/accuracy`: Accuracy on masked tokens
+- `train/mask_ratio`: Average fraction of tokens masked
+- `train/num_masked_tokens`: Number of tokens masked
+- `train/avg_p_mask`: Average masking probability
+- `train/ce_loss`: Unweighted cross-entropy loss
+- `train/importance_weight_avg`: Average importance weight
+
+## Limitations
+
+- No flash attention support
+- No RL training support
+
+## References
+
+- [LLaDA Paper](https://arxiv.org/abs/2502.09992)
+- [Axolotl Documentation](https://docs.axolotl.ai/)
+- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)
diff --git a/src/axolotl/integrations/diffusion/__init__.py b/src/axolotl/integrations/diffusion/__init__.py
new file mode 100644
index 000000000..9e38cc5c1
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/__init__.py
@@ -0,0 +1,19 @@
+"""Diffusion LM training plugin init."""
+
+from .args import DiffusionArgs, DiffusionConfig
+from .callbacks import DiffusionGenerationCallback
+from .generation import generate
+from .plugin import DiffusionPlugin
+from .trainer import DiffusionTrainer
+from .utils import create_bidirectional_attention_mask, resolve_mask_token_id
+
+__all__ = [
+ "DiffusionArgs",
+ "DiffusionPlugin",
+ "DiffusionTrainer",
+ "generate",
+ "resolve_mask_token_id",
+ "create_bidirectional_attention_mask",
+ "DiffusionGenerationCallback",
+ "DiffusionConfig",
+]
diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py
new file mode 100644
index 000000000..4f5bfe499
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/args.py
@@ -0,0 +1,95 @@
+"""Config args for diffusion LM training (nested under `diffusion:`)."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class DiffusionConfig(BaseModel):
+ """Nested diffusion configuration available under the `diffusion` key."""
+
+ # Noise schedule config
+ noise_schedule: Literal["linear", "cosine"] = Field(
+ default="linear", description="Type of noise schedule for diffusion training"
+ )
+ min_mask_ratio: float = Field(
+ default=0.1,
+ ge=0.0,
+ le=1.0,
+ description="Minimum masking ratio for diffusion noise schedule",
+ )
+ max_mask_ratio: float = Field(
+ default=0.9,
+ ge=0.0,
+ le=1.0,
+ description="Maximum masking ratio for diffusion noise schedule",
+ )
+ num_diffusion_steps: int = Field(
+ default=128, ge=1, description="Number of diffusion timesteps"
+ )
+ eps: float = Field(
+ default=1e-3,
+ ge=0.0,
+ le=1.0,
+ description="Epsilon value for minimum masking probability in forward process",
+ )
+
+ # Training config
+ importance_weighting: bool = Field(
+ default=True,
+ description="Apply importance weighting to loss based on masking probability",
+ )
+ mask_token_id: int | None = Field(
+ default=None,
+ description=(
+ "Token ID to use for masking. Unset by default; can use one of the "
+ "tokenizer's special tokens here."
+ ),
+ )
+ mask_token_str: str | None = Field(
+ default=None,
+ description=(
+ "Token string to use as a mask. If `mask_token_id` is invalid or unset, "
+ "this token will be ensured to exist as an additional special token and "
+ "used. If absent, a default '<|diffusion_mask|>' will be added."
+ ),
+ )
+
+ # Sample generation config
+ generate_samples: bool = Field(
+ default=True, description="Enable sample generation during training"
+ )
+ generation_interval: int = Field(
+ default=100, ge=1, description="Generate samples every N steps"
+ )
+ num_generation_samples: int = Field(
+ default=3, ge=1, description="Number of samples to generate each time"
+ )
+ generation_steps: int = Field(
+ default=128, ge=1, description="Number of diffusion steps for generation"
+ )
+ generation_temperature: float = Field(
+ default=0.0,
+ ge=0.0,
+ description="Temperature for generation sampling (0.0 = deterministic)",
+ )
+ generation_max_length: int = Field(
+ default=100, ge=1, description="Maximum sequence length for generation"
+ )
+
+ @model_validator(mode="after")
+ def _validate_mask_ratios(self) -> "DiffusionConfig":
+ if self.min_mask_ratio > self.max_mask_ratio:
+ raise ValueError("min_mask_ratio must be <= max_mask_ratio")
+ return self
+
+
+class DiffusionArgs(BaseModel):
+ """Plugin entry that exposes the nested `diffusion` block to the core config."""
+
+ diffusion: DiffusionConfig = Field(
+ default_factory=DiffusionConfig,
+ description="Diffusion training configuration. Only nested block is supported.",
+ )
diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py
new file mode 100644
index 000000000..18a64023b
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/callbacks.py
@@ -0,0 +1,174 @@
+"""Callbacks for diffusion training."""
+
+import logging
+import sys
+
+import wandb
+from colorama import Fore, Style
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+
+from .generation import generate_samples
+
+# Simpler logger for more readable sample generation
+logger = logging.getLogger(__name__)
+if not logger.handlers:
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(logging.Formatter("%(message)s"))
+ logger.addHandler(handler)
+ logger.propagate = False
+logger.setLevel(logging.INFO)
+
+
+class DiffusionGenerationCallback(TrainerCallback):
+ """Callback for generating samples during diffusion training."""
+
+ def __init__(self, trainer):
+ self.trainer = trainer
+
+ def on_step_end(
+ self,
+ args: TrainingArguments,
+ state: TrainerState,
+ control: TrainerControl,
+ **kwargs,
+ ):
+ """Generate samples at specified intervals."""
+ if (
+ state.global_step > 0
+ and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
+ ):
+ if not self.trainer.state.is_world_process_zero:
+ return
+
+ # Use eval dataloader if available, otherwise use train dataloader
+ dataloader = None
+ try:
+ if getattr(self.trainer, "eval_dataset", None) is not None:
+ dataloader = self.trainer.get_eval_dataloader()
+ except Exception:
+ dataloader = None
+ if dataloader is None:
+ dataloader = self.trainer.get_train_dataloader()
+
+ # Generate samples
+ diffusion_cfg = self.trainer.cfg.diffusion
+ samples = generate_samples(
+ model=self.trainer.model,
+ tokenizer=self.trainer.processing_class,
+ dataloader=dataloader,
+ num_generation_samples=diffusion_cfg.num_generation_samples,
+ max_length=diffusion_cfg.generation_max_length,
+ num_diffusion_steps=diffusion_cfg.generation_steps,
+ temperature=diffusion_cfg.generation_temperature,
+ mask_token_id=diffusion_cfg.mask_token_id,
+ )
+
+ # Log samples
+ self._log_samples(samples, state.global_step)
+
+ def _log_samples(self, samples: list, step: int):
+ """Log generated samples."""
+ if not samples:
+ return
+
+ logger.info("=" * 60)
+ logger.info("GENERATED SAMPLES")
+ logger.info("=" * 60)
+
+ for i, sample_data in enumerate(samples, 1):
+ original = sample_data["original"]
+ masked = sample_data["masked"]
+ generated = sample_data["generated"]
+ mask_ratio = sample_data["mask_ratio"]
+ masked_tokens = sample_data["masked_tokens"]
+ total_tokens = sample_data["total_tokens"]
+
+ logger.info(f"\nSample {i}:")
+ logger.info(f"\tOriginal ({total_tokens} tokens): {original}")
+ logger.info(
+ f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
+ f"{mask_ratio:.1%}): {masked}"
+ )
+
+ try:
+ gen_ids = sample_data.get("generated_ids")
+ orig_ids = sample_data.get("orig_ids")
+ masked_positions = set(sample_data.get("masked_positions") or [])
+ if isinstance(gen_ids, list) and isinstance(orig_ids, list):
+ styles: list[str] = []
+ for idx, tid in enumerate(gen_ids):
+ if idx in masked_positions:
+ if idx < len(orig_ids) and tid == orig_ids[idx]:
+ styles.append("green")
+ elif idx < len(orig_ids):
+ styles.append("red")
+ else:
+ styles.append("normal")
+ else:
+ same = idx < len(orig_ids) and tid == orig_ids[idx]
+ styles.append("dim" if same else "normal")
+
+ spans: list[tuple[str, int, int]] = []
+ if gen_ids:
+ cur = styles[0]
+ start = 0
+ for idx in range(1, len(gen_ids)):
+ s = styles[idx]
+ if s != cur:
+ spans.append((cur, start, idx))
+ cur, start = s, idx
+ spans.append((cur, start, len(gen_ids)))
+
+ parts = []
+ for style_name, a, b in spans:
+ chunk_text = self.trainer.processing_class.decode(
+ gen_ids[a:b], skip_special_tokens=False
+ )
+ if style_name == "green":
+ parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
+ elif style_name == "red":
+ parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
+ elif style_name == "dim":
+ parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
+ else:
+ parts.append(chunk_text)
+ logger.info("\tGenerated:\n%s", "".join(parts))
+ else:
+ logger.info(f"\tGenerated: {generated}")
+ except Exception:
+ logger.info(f"\tGenerated: {generated}")
+
+ logger.info("=" * 60)
+
+ if self.trainer.cfg.use_wandb:
+ if wandb.run is not None:
+ wandb.log(
+ {
+ "generated_samples": wandb.Table(
+ columns=[
+ "step",
+ "original",
+ "masked",
+ "generated",
+ "mask_ratio",
+ "masked_tokens",
+ "total_tokens",
+ ],
+ data=[
+ [
+ step,
+ sample["original"],
+ sample["masked"],
+ sample["generated"],
+ f"{sample['mask_ratio']:.1%}",
+ sample["masked_tokens"],
+ sample["total_tokens"],
+ ]
+ for sample in samples
+ ],
+ )
+ },
+ step=step,
+ )
diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py
new file mode 100644
index 000000000..49e3cdfae
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/generation.py
@@ -0,0 +1,409 @@
+"""Sample generation utilities for diffusion training."""
+
+import re
+from typing import Any, List, Literal, Optional
+
+import torch
+
+from axolotl.utils.logging import get_logger
+
+from .utils import create_bidirectional_attention_mask
+
+LOG = get_logger(__name__)
+
+
+def generate_samples(
+ model: torch.nn.Module,
+ tokenizer: Any,
+ dataloader: Optional[Any] = None,
+ num_generation_samples: int = 3,
+ max_length: int = 100,
+ num_diffusion_steps: int = 128,
+ temperature: float = 0.0,
+ mask_token_id: int = 32000,
+ mode: Literal["random", "completion"] = "random",
+ completion_tokens: int = 0,
+ target_mask_ratio: Optional[float] = None,
+) -> List[dict]:
+ """
+ Generate text samples using the diffusion model by randomly masking sequences from
+ the given dataset and running the reverse diffusion process.
+
+ Args:
+ model: The wrapped or unwrapped model
+ tokenizer: Tokenizer for encoding/decoding
+ dataloader: Validation dataloader (for sampling sequences)
+ num_generation_samples: Number of samples to generate
+ max_length: Maximum length of sequences to use
+ num_diffusion_steps: Number of diffusion steps for generation
+ temperature: Temperature for sampling (0.0 = deterministic)
+ mask_token_id: Token ID used for masking
+ mode: "random" masks existing tokens; "completion" appends masked tokens
+ completion_tokens: Number of masked tokens to append in completion mode
+ target_mask_ratio: Fixed mask ratio for random mode (sampled if None)
+
+ Returns:
+ List of dictionaries with original text, masked text, and generated text
+ """
+ if dataloader is None:
+ LOG.warning("No validation dataloader provided, cannot generate samples")
+ return []
+
+ unwrapped_model = model.module if hasattr(model, "module") else model
+ training = unwrapped_model.training
+ unwrapped_model.eval()
+
+ # Resolve device robustly (some modules don't expose `.device`)
+ device = getattr(unwrapped_model, "device", None)
+ if device is None:
+ try:
+ device = next(unwrapped_model.parameters()).device
+ except StopIteration:
+ device = torch.device("cpu")
+ generations = []
+
+ # Sample sequences from validation dataset
+ sampled_sequences = _sample_sequences_from_dataloader(
+ dataloader, num_generation_samples, max_length, device
+ )
+ LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
+
+ # Generate samples using reverse diffusion process
+ with torch.no_grad():
+ for sample in sampled_sequences:
+ if isinstance(sample, dict):
+ original_sequence = sample.get("input_ids")
+ labels_seq = sample.get("labels")
+ attn_seq = sample.get("attention_mask")
+ else:
+ original_sequence = sample
+ labels_seq = None
+ attn_seq = None
+ generation_result = generate(
+ unwrapped_model,
+ tokenizer,
+ original_sequence,
+ num_diffusion_steps,
+ temperature,
+ mask_token_id,
+ mode=mode,
+ completion_tokens=completion_tokens,
+ target_mask_ratio=target_mask_ratio,
+ labels=labels_seq,
+ attention_mask=attn_seq,
+ )
+ generations.append(generation_result)
+
+ # Restore prior training state
+ if training:
+ unwrapped_model.train()
+ else:
+ unwrapped_model.eval()
+
+ return generations
+
+
+def _sample_sequences_from_dataloader(
+ dataloader: Any, num_samples: int, max_length: int, device: torch.device
+) -> List[Any]:
+ """Sample sequences from validation dataloader."""
+ sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []
+ sample_count = 0
+
+ # Skip a random number of batches (we could be more clever about this)
+ skip_batches = torch.randint(0, 10, (1,)).item()
+ batch_count = 0
+
+ for batch in dataloader:
+ # Skip some batches for variety
+ if batch_count < skip_batches:
+ batch_count += 1
+ continue
+
+ if sample_count >= num_samples:
+ break
+
+ batch_count += 1
+ input_ids = batch["input_ids"]
+ attention_mask = batch.get("attention_mask")
+ labels = batch.get("labels")
+
+ # Randomly sample from sequences in this batch
+ batch_indices = torch.randperm(input_ids.size(0)).tolist()
+
+ for i in batch_indices:
+ if sample_count >= num_samples:
+ break
+
+ # Get actual sequence length (non-padded)
+ if attention_mask is not None:
+ seq_len = attention_mask[i].sum().item()
+ else:
+ seq_len = input_ids.size(1)
+
+ if seq_len < 10:
+ continue
+
+ # Determine truncation length
+ max_total = min(seq_len, max_length)
+ if labels is not None:
+ labels_i = labels[i][:seq_len]
+ answer_mask = labels_i != -100
+ if not answer_mask.any():
+ # No answer tokens; skip for SFT masking
+ continue
+ first_ans_idx = int(
+ torch.nonzero(answer_mask, as_tuple=False)[0].item()
+ )
+ prompt_len = first_ans_idx
+ if prompt_len >= max_total:
+ # Prompt alone reaches cap; cannot include any answer
+ continue
+ remaining_answer = int(answer_mask[prompt_len:].sum().item())
+ allowed_answer = max_total - prompt_len
+ take_answer = min(remaining_answer, allowed_answer)
+ if take_answer <= 0:
+ continue
+ actual_length = prompt_len + take_answer
+ else:
+ actual_length = max_total
+
+ # Extract the (possibly truncated) sequence
+ sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
+ attn_seq = (
+ attention_mask[i][:actual_length].unsqueeze(0).to(device)
+ if attention_mask is not None
+ else None
+ )
+ if labels is not None:
+ labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)
+ sampled_sequences.append(
+ {
+ "input_ids": sequence,
+ "labels": labels_seq,
+ "attention_mask": attn_seq,
+ }
+ )
+ else:
+ if attn_seq is not None:
+ sampled_sequences.append(
+ {"input_ids": sequence, "attention_mask": attn_seq}
+ )
+ else:
+ sampled_sequences.append(sequence)
+ sample_count += 1
+
+ return sampled_sequences
+
+
+def generate(
+ model: torch.nn.Module,
+ tokenizer: Any,
+ original_sequence: torch.Tensor,
+ num_diffusion_steps: int,
+ temperature: float,
+ mask_token_id: int,
+ *,
+ mode: Literal["random", "completion"] = "random",
+ completion_tokens: int = 0,
+ target_mask_ratio: Optional[float] = None,
+ labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> dict:
+ """Generate a single sample using reverse diffusion."""
+ # Get original text for comparison
+ original_text = tokenizer.decode(
+ original_sequence[0].cpu(), skip_special_tokens=True
+ )
+
+ # Build masked sequence
+ if (
+ labels is not None
+ and labels.numel() > 0
+ and (labels == -100).any()
+ and (labels != -100).any()
+ ):
+ # SFT case: completely mask all answer tokens (labels != -100)
+ total_tokens = original_sequence.size(1)
+ masked_indices = (labels != -100).to(dtype=torch.bool)
+ masked_sequence = original_sequence.clone()
+ masked_sequence[masked_indices] = mask_token_id
+ masked_tokens = int(masked_indices.sum().item())
+ mask_ratio = masked_tokens / max(int(total_tokens), 1)
+ elif mode == "completion" and completion_tokens > 0:
+ # Append mask tokens to the right for completion
+ total_tokens = original_sequence.size(1) + int(completion_tokens)
+ masked_indices = torch.zeros(
+ 1, total_tokens, dtype=torch.bool, device=original_sequence.device
+ )
+ masked_indices[0, -int(completion_tokens) :] = True
+
+ append = torch.full(
+ (1, int(completion_tokens)), mask_token_id, device=original_sequence.device
+ )
+ masked_sequence = torch.cat([original_sequence, append], dim=1)
+ masked_tokens = int(completion_tokens)
+ mask_ratio = masked_tokens / total_tokens
+ else:
+ # Apply random masking with optional fixed ratio
+ total_tokens = original_sequence.size(1)
+ if target_mask_ratio is None:
+ min_ratio, max_ratio = 0.1, 0.7
+ target_mask_ratio = (
+ torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
+ )
+ target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))
+
+ # Create random mask indices
+ mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
+ masked_indices = torch.zeros(
+ 1, total_tokens, dtype=torch.bool, device=original_sequence.device
+ )
+ masked_indices[0, mask_positions] = True
+
+ # Create masked sequence
+ masked_sequence = original_sequence.clone()
+ masked_sequence[masked_indices] = mask_token_id
+
+ # Calculate actual mask ratio
+ masked_tokens = masked_indices.sum().item()
+ mask_ratio = masked_tokens / total_tokens
+
+ # Get masked text for comparison
+ masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
+ masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
+
+ # Run reverse diffusion process
+ sequence = masked_sequence.clone()
+ attention_mask = create_bidirectional_attention_mask(
+ sequence, attention_mask, sample_packing=attention_mask is not None
+ )
+ for step in range(num_diffusion_steps):
+ sequence = _diffusion_step(
+ model,
+ sequence,
+ step,
+ num_diffusion_steps,
+ temperature,
+ mask_token_id,
+ attention_mask,
+ )
+ generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
+
+ # Collect diagnostic info
+ final_ids = sequence[0].detach().cpu().tolist()
+ orig_ids_for_render = original_sequence[0].detach().cpu().tolist()
+ if masked_indices is not None:
+ masked_positions = (
+ torch.where(masked_indices[0])[0].detach().cpu().tolist()
+ if masked_indices.ndim == 2
+ else []
+ )
+ else:
+ masked_positions = []
+
+ result = {
+ "original": original_text,
+ "masked": masked_text,
+ "generated": generated_text,
+ "mask_ratio": mask_ratio,
+ "masked_tokens": masked_tokens,
+ "total_tokens": total_tokens,
+ "generated_ids": final_ids,
+ "masked_positions": masked_positions,
+ "orig_ids": orig_ids_for_render,
+ "formatted": (
+ f"Original: '{original_text}' → Masked: '{masked_text}' "
+ f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
+ ),
+ }
+
+ return result
+
+
+def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
+ """Clean up masked text for display."""
+ mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
+ cleaned = masked_text.replace(mask_token_repr, "[MASK]")
+
+ # Remove literal special token strings
+ if hasattr(tokenizer, "special_tokens_map"):
+ for token_value in tokenizer.special_tokens_map.values():
+ if token_value and isinstance(token_value, str):
+ cleaned = cleaned.replace(token_value, "")
+
+ # Normalize whitespace but preserve newlines
+ cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
+ cleaned = re.sub(r"[ \t]+", " ", cleaned)
+ cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip()
+ return cleaned
+
+
+def _diffusion_step(
+ model: torch.nn.Module,
+ sequence: torch.Tensor,
+ step: int,
+ num_diffusion_steps: int,
+ temperature: float,
+ mask_token_id: int,
+ attention_mask: torch.Tensor | None = None,
+) -> torch.Tensor:
+ """Perform a single diffusion step with remasking."""
+ # Only process if there are masked tokens remaining
+ current_mask = sequence == mask_token_id
+ if not current_mask.any():
+ return sequence
+
+ # Create or use provided attention mask
+ if attention_mask is None:
+ batch_size, seq_len = sequence.shape
+ attention_mask = torch.ones(
+ batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
+ )
+
+ # Forward pass
+ outputs = model(input_ids=sequence, attention_mask=attention_mask)
+ logits = outputs.logits
+
+ # Only sample at currently masked positions
+ if current_mask.any():
+ masked_logits = logits[current_mask]
+
+ # Apply temperature scaling
+ if temperature > 0:
+ scaled_logits = masked_logits / temperature
+ else:
+ scaled_logits = masked_logits
+
+ # Suppress mask token in outputs
+ scaled_logits[:, mask_token_id] = -float("inf")
+
+ if temperature > 0:
+ # Add Gumbel noise for sampling
+ gumbel_noise = -torch.log(
+ -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
+ )
+ gumbel_logits = scaled_logits + gumbel_noise
+ predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
+ else:
+ predicted_tokens = torch.argmax(scaled_logits, dim=-1)
+
+ # Calculate probabilities for confidence scoring
+ probs = torch.softmax(scaled_logits, dim=-1)
+ predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
+
+ # Determine how many tokens to unmask this step
+ remaining_masked = current_mask.sum().item()
+ if step == num_diffusion_steps - 1:
+ num_to_unmask = remaining_masked
+ else:
+ unmask_ratio = 1.0 / (num_diffusion_steps - step)
+ num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
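+ # e.g. 8 tokens still masked with 4 steps left -> unmask about 2 this step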
+
+ # Select highest confidence predictions to unmask
+ if num_to_unmask >= remaining_masked:
+ sequence[current_mask] = predicted_tokens
+ else:
+ _, top_indices = predicted_token_probs.topk(num_to_unmask)
+ mask_positions = torch.where(current_mask)[1]
+ positions_to_unmask = mask_positions[top_indices]
+ sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
+
+ return sequence
diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py
new file mode 100644
index 000000000..c31f48b03
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/plugin.py
@@ -0,0 +1,41 @@
+"""Diffusion LM training plugin for Axolotl."""
+
+from peft import PeftModel
+from transformers import PreTrainedModel
+
+from axolotl.integrations.base import BasePlugin
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+from .trainer import DiffusionTrainer
+
+LOG = get_logger(__name__)
+
+
+class DiffusionPlugin(BasePlugin):
+ """
+ Plugin for diffusion language model training.
+
+ This plugin enables diffusion-based training using the LLaDA approach, which uses
+ random masking and bidirectional attention to train language models.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.cfg = None
+
+ def get_input_args(self) -> str:
+ """Returns the pydantic model for LLaDA plugin arguments."""
+ return "axolotl.integrations.diffusion.DiffusionArgs"
+
+ def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
+ """Perform actions after model is loaded."""
+ self.cfg = cfg
+
+ def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
+ """Return custom trainer class for diffusion training."""
+ return DiffusionTrainer
+
+ def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
+ """Configure trainer after creation."""
+ trainer.set_config(cfg)
diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py
new file mode 100644
index 000000000..42b2468f4
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/trainer.py
@@ -0,0 +1,301 @@
+"""Custom trainer for diffusion LM training."""
+
+from typing import Any, Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from axolotl.core.trainers.base import AxolotlTrainer
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+from .callbacks import DiffusionGenerationCallback
+from .utils import create_bidirectional_attention_mask
+
+LOG = get_logger(__name__)
+
+
+class DiffusionTrainer(AxolotlTrainer):
+ """Custom trainer for diffusion LM training that overrides loss computation."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.cfg = None
+ self._special_token_ids = None
+
+ def set_config(self, config: DictDefault):
+ """Set config for diffusion training."""
+ self.cfg = config
+ self._cache_special_token_ids()
+ self._resolve_mask_token_id()
+
+ token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
+ LOG.info(f"Diffusion: using mask_token_id={token_id}")
+
+ if getattr(config.diffusion, "generate_samples", True):
+ generation_callback = DiffusionGenerationCallback(self)
+ self.add_callback(generation_callback)
+
+ def _resolve_mask_token_id(self) -> None:
+ """Ensure mask_token_id is valid for the current tokenizer."""
+ from .utils import resolve_mask_token_id
+
+ tokenizer = getattr(self, "processing_class", None)
+ if tokenizer is None:
+ return
+
+ mid = resolve_mask_token_id(
+ tokenizer,
+ self.cfg,
+ allow_add=True,
+ model=getattr(self, "model", None),
+ )
+ try:
+ self.cfg.diffusion.mask_token_id = int(mid)
+ except Exception:
+ pass
+
+ def compute_loss(
+ self,
+ model: nn.Module,
+ inputs: dict[str, torch.Tensor],
+ return_outputs: bool = False,
+ num_items_in_batch: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
+ """Override compute_loss to use diffusion loss."""
+ input_ids = inputs.get("input_ids")
+ attention_mask = inputs.get("attention_mask")
+ labels = inputs.get("labels")
+
+ if input_ids is None:
+ raise ValueError("input_ids is required for diffusion training")
+
+ loss, outputs = self._compute_diffusion_loss(
+ model, input_ids, attention_mask, labels
+ )
+
+ if return_outputs:
+ return loss, outputs
+ return loss
+
+ def _cache_special_token_ids(self):
+ """Cache special token IDs to avoid repeated tokenizer access."""
+ if self.processing_class is None:
+ self._special_token_ids = set()
+ return
+
+ tokenizer = self.processing_class
+ special_tokens = set()
+
+ if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
+ special_tokens.add(tokenizer.bos_token_id)
+ if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
+ special_tokens.add(tokenizer.eos_token_id)
+ if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
+ special_tokens.add(tokenizer.pad_token_id)
+
+ self._special_token_ids = special_tokens
+
+ def _forward_process(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ labels: torch.Tensor | None = None,
+ eps: float = 1e-3,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Forward noising process. A timestep is sampled along the process, and tokens are
+ masked with probability determined by the configured noise schedule.
+
+ Args:
+ input_ids: Input token ids [batch_size, seq_len].
+ attention_mask: Attention mask [batch_size, seq_len].
+ labels: Labels for SFT training [batch_size, seq_len].
+ eps: Small epsilon value for minimum masking probability.
+
+ Returns:
+ noisy_batch: Input with some tokens masked.
+ masked_indices: Boolean mask indicating which tokens were masked.
+ p_mask: Masking probabilities for each token [batch_size, seq_len].
+ """
+ batch_size, seq_len = input_ids.shape
+ device = input_ids.device
+
+ # Sample random timesteps for each sample in batch
+ t = torch.rand(batch_size, device=device)
+ p_mask = (1 - eps) * t + eps # [batch_size]
+ p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
+
+ # Don't mask padding tokens if attention_mask is provided
+ if attention_mask is not None:
+ valid_mask = attention_mask.bool()
+ p_mask = p_mask * valid_mask.float()
+
+ # Create mask to exclude special tokens
+ special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+ if self._special_token_ids:
+ for token_id in self._special_token_ids:
+ special_token_mask |= input_ids == token_id
+
+ # Create random mask based on p_mask
+ masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
+ masked_indices = masked_indices & ~special_token_mask
+ if attention_mask is not None:
+ masked_indices = masked_indices & attention_mask.bool()
+
+ # For SFT data, only mask answer tokens
+ if labels is not None:
+ answer_mask = labels != -100
+ masked_indices = masked_indices & answer_mask
+
+ # Create masked input
+ mask_token_id = int(self.cfg.diffusion.mask_token_id)
+ mask_value = torch.full_like(input_ids, mask_token_id)
+ noisy_batch = torch.where(masked_indices, mask_value, input_ids)
+
+ return noisy_batch, masked_indices, p_mask
+
+ def _compute_diffusion_loss(
+ self,
+ model: nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ labels: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | Any]:
+ """
+ Compute diffusion loss.
+
+ Args:
+ model: The model to compute loss for.
+ input_ids: Ground truth token ids [batch_size, seq_len].
+ attention_mask: Attention mask [batch_size, seq_len].
+ labels: Labels for SFT training [batch_size, seq_len].
+
+ Returns:
+ loss: Cross-entropy loss.
+ metrics: Dictionary of metrics.
+ """
+ # Short-circuit empty sequences
+ if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0:
+ zero = torch.tensor(
+ 0.0,
+ device=(input_ids.device if input_ids is not None else None),
+ requires_grad=True,
+ )
+ return zero, {}
+
+ # If an attention_mask is provided and all positions are padding for every
+ # sample in this batch, skip the step.
+ if attention_mask is not None:
+ if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all():
+ zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
+ return zero, {}
+
+ # Apply forward process
+ noisy_batch, masked_indices, p_mask = self._forward_process(
+ input_ids, attention_mask, labels, self.cfg.diffusion.eps
+ )
+
+ # Create bidirectional attention mask
+ bidirectional_mask = create_bidirectional_attention_mask(
+ input_ids, attention_mask, sample_packing=self.cfg.sample_packing
+ )
+
+ # Forward pass
+ outputs = model(
+ input_ids=noisy_batch.long(),
+ attention_mask=bidirectional_mask,
+ )
+ logits = outputs.logits
+
+ if masked_indices.sum() > 0:
+ valid_indices = torch.where(masked_indices)
+ batch_indices, seq_indices = valid_indices
+
+ masked_logits = logits[batch_indices, seq_indices]
+ masked_targets = input_ids[batch_indices, seq_indices]
+ masked_p_mask = p_mask[batch_indices, seq_indices]
+
+ # Compute cross-entropy loss without reduction
+ token_loss = F.cross_entropy(
+ masked_logits.float(), masked_targets, reduction="none"
+ )
+
+ if self.cfg.diffusion.importance_weighting:
+ masked_p_mask = masked_p_mask.float()
+ weighted_loss = token_loss / masked_p_mask
+ else:
+ weighted_loss = token_loss
+
+ if labels is not None:
+ # For SFT data: normalize by answer token count per sample
+ answer_mask = labels != -100
+ answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
+
+ # Get batch indices for masked tokens
+ masked_batch_indices = batch_indices
+
+ # Sum losses per sample and divide by answer length
+ batch_size = input_ids.shape[0]
+ loss_per_sample = torch.zeros(batch_size, device=input_ids.device)
+ for i in range(batch_size):
+ sample_mask = masked_batch_indices == i
+ if sample_mask.sum() > 0:
+ sample_loss = weighted_loss[sample_mask].sum()
+ denom = answer_lengths[i].clamp(min=1.0)
+ loss_per_sample[i] = sample_loss / denom
+
+ loss = loss_per_sample.mean()
+ else:
+ # Non-SFT: when importance weighting is enabled, use unbiased estimator
+ # (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
+ # for stable scaling across varying mask ratios.
+ if self.cfg.diffusion.importance_weighting:
+ loss = weighted_loss.sum() / (
+ input_ids.shape[0] * input_ids.shape[1]
+ )
+ else:
+ loss = weighted_loss.mean()
+
+ ce_loss = token_loss.mean()
+
+ # Compute accuracy on masked tokens
+ with torch.no_grad():
+ pred_tokens = masked_logits.argmax(dim=-1)
+ accuracy = (pred_tokens == masked_targets).float().mean()
+ else:
+ loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
+ accuracy = torch.tensor(0.0, device=input_ids.device)
+ ce_loss = torch.tensor(0.0, device=input_ids.device)
+ masked_p_mask = torch.tensor(1.0, device=input_ids.device)
+
+ avg_p_mask = (
+ p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0
+ )
+ metrics = {
+ "loss": loss.item(),
+ "accuracy": accuracy.item(),
+ "mask_ratio": masked_indices.float().mean().item(),
+ "num_masked_tokens": (masked_indices.sum().item(), "sum"),
+ "avg_p_mask": avg_p_mask,
+ "ce_loss": ce_loss.item(),
+ }
+
+ # If doing SFT training (labels present), log answer-specific metrics
+ if labels is not None:
+ with torch.no_grad():
+ answer_mask = labels != -100
+ answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
+ total_answer_tokens = answer_mask.sum().item() # type: ignore
+ total_tokens = labels.numel() # type: ignore
+ metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
+ metrics["avg_answer_length"] = answer_lengths.mean().item()
+
+ if self.cfg.diffusion.importance_weighting:
+ metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
+
+ train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
+ self.store_metrics(metrics, train_eval=train_eval)
+
+ return loss, outputs
diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py
new file mode 100644
index 000000000..47abf6fec
--- /dev/null
+++ b/src/axolotl/integrations/diffusion/utils.py
@@ -0,0 +1,159 @@
+"""Shared utilities for diffusion integration."""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import torch
+
+from axolotl.utils.dict import DictDefault
+
+
+def resolve_mask_token_id(
+ tokenizer: Any,
+ cfg: DictDefault,
+ *,
+ allow_add: bool,
+ model: Any | None = None,
+ default_token: str = "<|diffusion_mask|>",
+) -> int:
+ """Resolve mask token id. Training may add a new special token; inference won't."""
+ # Determine vocab size if available
+ vocab_size = None
+ if tokenizer is not None:
+ if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None:
+ try:
+ vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type]
+ except Exception:
+ vocab_size = None
+ elif hasattr(tokenizer, "__len__"):
+ try:
+ vocab_size = int(len(tokenizer))
+ except Exception:
+ vocab_size = None
+
+ # Use explicit id from config if provided
+ diffusion_cfg = getattr(cfg, "diffusion", None)
+ # Fallback to top-level attr names only if nested missing (shouldn't happen)
+ cfg_id = (
+ getattr(diffusion_cfg, "mask_token_id", None)
+ if diffusion_cfg is not None
+ else getattr(cfg, "diffusion_mask_token_id", None)
+ )
+ if isinstance(cfg_id, int) and cfg_id >= 0:
+ if vocab_size is None or cfg_id < vocab_size:
+ return int(cfg_id)
+
+ def _existing_special_token_id(token_str: str | None) -> int | None:
+ """Attempt to resolve an existing special token string to a real ID."""
+ if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"):
+ return None
+ try:
+ token_id = tokenizer.convert_tokens_to_ids(token_str)
+ except Exception:
+ return None
+
+ if not isinstance(token_id, int) or token_id < 0:
+ return None
+
+ # Ensure it's registered as special and not UNK, and within vocab
+ unk_id = getattr(tokenizer, "unk_token_id", None)
+ specials = set(getattr(tokenizer, "all_special_tokens", []) or [])
+ addl = set(getattr(tokenizer, "additional_special_tokens", []) or [])
+ is_special = token_str in specials or token_str in addl
+ in_vocab = vocab_size is None or token_id < vocab_size
+ if (
+ (unk_id is not None and token_id == unk_id)
+ or not is_special
+ or not in_vocab
+ ):
+ return None
+ return token_id
+
+ # Try mask token string if provided
+ token_str = (
+ getattr(diffusion_cfg, "mask_token_str", None)
+ if diffusion_cfg is not None
+ else getattr(cfg, "diffusion_mask_token_str", None)
+ )
+ for candidate in (token_str, default_token):
+ token_id = _existing_special_token_id(candidate)
+ if isinstance(token_id, int):
+ try:
+ if diffusion_cfg is None:
+ cfg.diffusion_mask_token_id = int(token_id) # legacy fallback
+ else:
+ diffusion_cfg.mask_token_id = int(token_id)
+ except Exception:
+ pass
+ return int(token_id)
+
+ # Optionally add and return a dedicated special token during training
+ if allow_add and hasattr(tokenizer, "add_special_tokens"):
+ token_to_add = token_str or default_token
+ try:
+ tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]})
+
+ # Resize embeddings if possible
+ if (
+ model is not None
+ and hasattr(tokenizer, "__len__")
+ and hasattr(model, "resize_token_embeddings")
+ ):
+ try:
+ model.resize_token_embeddings(len(tokenizer))
+ except Exception:
+ pass
+ new_id = tokenizer.convert_tokens_to_ids(token_to_add)
+ if isinstance(new_id, int) and new_id >= 0:
+ try:
+ if diffusion_cfg is None:
+ cfg.diffusion_mask_token_id = int(new_id) # legacy fallback
+ else:
+ diffusion_cfg.mask_token_id = int(new_id)
+ except Exception:
+ pass
+ return int(new_id)
+ except Exception:
+ pass
+
+ # Fallback to unk or 0 (do not update cfg)
+ fallback = getattr(tokenizer, "unk_token_id", 0) or 0
+ return int(fallback)
+
+
+def create_bidirectional_attention_mask(
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ sample_packing: bool = False,
+) -> torch.Tensor:
+ """
+ Create bidirectional attention mask to override default causal masking.
+ Handles sample-packed sequences where different samples are identified
+ by different attention mask values.
+
+ Args:
+ input_ids: Input token ids [batch_size, seq_len]
+ attention_mask: Attention mask [batch_size, seq_len]
+ sample_packing: Whether sample packing is enabled
+
+ Returns:
+ bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
+ """
+ batch_size, seq_len = input_ids.shape
+ device = input_ids.device
+
+ if attention_mask is None or not sample_packing:
+ return torch.ones(
+ batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
+ )
+
+ # Handle sample packing: tokens can only attend within their sample
+ mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
+ mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
+
+ # Tokens can attend to each other if they have the same non-zero sample ID
+ bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
+
+ # Add head dimension: [batch_size, 1, seq_len, seq_len]
+ return bidirectional_mask.unsqueeze(1)
diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 989b34aee..bcde4bf96 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -14,6 +14,7 @@ from peft import (
PeftConfig,
PeftMixedModel,
PeftModel,
+ TaskType,
get_peft_model,
)
from transformers import PreTrainedModel
@@ -101,6 +102,15 @@ def load_lora(
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
+ # Determine the correct PEFT task type
+ model_cls = type(model).__name__
+ if "SequenceClassification" in model_cls:
+ task_type = TaskType.SEQ_CLS
+ elif "TokenClassification" in model_cls:
+ task_type = TaskType.TOKEN_CLS
+ else:
+ task_type = TaskType.CAUSAL_LM
+
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
@@ -112,7 +122,7 @@ def load_lora(
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
- task_type="CAUSAL_LM",
+ task_type=task_type,
**lora_config_kwargs,
)
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index a9507d685..f438d6b61 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -673,6 +673,33 @@ class ModelLoader:
return hf_ds_cfg
+ def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel:
+ """
+ Load model with random initialization using from_config.
+
+ Uses the selected loader when provided; otherwise falls back to the auto loader.
+ """
+ loader = model_loader_class or self.auto_model_loader
+ if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
+ model = loader.from_config(
+ config=self.model_config,
+ trust_remote_code=self.cfg.trust_remote_code or False,
+ )
+ else:
+ model = loader(config=self.model_config)
+
+ return model
+
+ def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
+ """Load model from pretrained weights."""
+ loader = model_loader_class or self.auto_model_loader
+ kwargs = {
+ "config": self.model_config,
+ "trust_remote_code": self.cfg.trust_remote_code or False,
+ **self.model_kwargs,
+ }
+ return loader.from_pretrained(self.base_model, **kwargs)
+
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
@@ -687,7 +714,8 @@ class ModelLoader:
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
- # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
+ # Don't delete device_map for QLoRA + FSDP - it was set correctly in
+ # _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
@@ -716,6 +744,11 @@ class ModelLoader:
or self.cfg.qlora_sharded_model_loading
)
):
+ if self.cfg.reinit_weights:
+ LOG.warning(
+ "reinit_weights is not supported with sharded quantized loading. "
+ "Loading from pretrained weights instead."
+ )
quant_storage = self.cfg.torch_dtype
quantization_config = getattr(
self.model_config, "quantization_config", None
@@ -731,33 +764,12 @@ class ModelLoader:
quantization_config=quantization_config,
)
skip_move_to_device = True
- elif (
- self.model_config.model_type in ["llama", "llama4"]
- and not self.cfg.trust_remote_code
- and not self.cfg.gptq
- ):
- # Please don't remove underscore binding without reading the fn docstring.
- _ = self._configure_zero3_memory_efficient_loading()
-
- # Load model with random initialization if specified
- if self.cfg.random_init_weights:
- # AutoModel classes support the from_config method
- if self.auto_model_loader in [
- AutoModelForCausalLM,
- AutoModelForVision2Seq,
- ]:
- self.model = self.auto_model_loader.from_config(
- config=self.model_config,
- )
- else:
- self.model = self.auto_model_loader(config=self.model_config)
- else:
- self.model = self.auto_model_loader.from_pretrained(
- self.base_model,
- config=self.model_config,
- **self.model_kwargs,
- )
elif self.model_type == "MambaLMHeadModel":
+ if self.cfg.reinit_weights:
+ LOG.warning(
+ "reinit_weights is not supported with MambaLMHeadModel. "
+ "Loading from pretrained weights instead."
+ )
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss()
@@ -770,41 +782,27 @@ class ModelLoader:
self.base_model,
**self.model_kwargs,
)
- elif (
- self.model_type
- and self.model_type != "AutoModelForCausalLM"
- and not self.cfg.trust_remote_code
- ):
- if self.cfg.gptq:
- self.model = self.auto_model_loader.from_pretrained(
- self.base_model,
- config=self.model_config,
- trust_remote_code=self.cfg.trust_remote_code or False,
- **self.model_kwargs,
- )
- else:
- self.model = getattr(transformers, self.model_type).from_pretrained(
- self.base_model,
- config=self.model_config,
- trust_remote_code=self.cfg.trust_remote_code or False,
- **self.model_kwargs,
- )
- elif self.cfg.gptq:
- self.model = self.auto_model_loader.from_pretrained(
- self.base_model,
- config=self.model_config,
- trust_remote_code=self.cfg.trust_remote_code or False,
- **self.model_kwargs,
- )
else:
- # Please don't remove underscore binding without reading the fn docstring.
+ # Please don't remove underscore binding without reading the fn docstring
_ = self._configure_zero3_memory_efficient_loading()
- self.model = self.auto_model_loader.from_pretrained(
- self.base_model,
- config=self.model_config,
- trust_remote_code=self.cfg.trust_remote_code or False,
- **self.model_kwargs,
- )
+
+ if (
+ self.model_type
+ and self.model_type != "AutoModelForCausalLM"
+ and not self.cfg.trust_remote_code
+ and not self.cfg.gptq
+ ):
+ # Resolve the concrete model class from transformers by name
+ model_loader_class = getattr(transformers, self.model_type)
+ else:
+ # Use auto model loader (handles gptq and default cases)
+ model_loader_class = self.auto_model_loader
+
+ if self.cfg.reinit_weights:
+ self.model = self._load_model_from_config(model_loader_class)
+ else:
+ self.model = self._load_model_from_pretrained(model_loader_class)
+
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 044c278a3..a5a630cb5 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -3,8 +3,8 @@
Applies pre- and post-model load patches for various fixes and optimizations.
"""
-import os
import importlib.util
+import os
from functools import cached_property
import addict
@@ -468,9 +468,10 @@ class PatchManager:
def _apply_patch_deepspeed_zero3(self):
try:
- from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
+
if self.cfg.activation_offloading is True and (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 3b38a33b7..d8ba02cb2 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -160,9 +160,11 @@ def get_state_dict(self, model, unwrap=True):
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
elif self.distributed_type == DistributedType.FSDP:
- from torch.distributed.fsdp import FullStateDictConfig
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp import StateDictType
+ from torch.distributed.fsdp import (
+ FullStateDictConfig,
+ FullyShardedDataParallel as FSDP,
+ StateDictType,
+ )
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True
diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index 65ccad533..678f65bee 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -1,11 +1,12 @@
"""Flex attention monkey patch"""
import sys
-from packaging import version
import torch
import transformers
+from packaging import version
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
+
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
diff --git a/src/axolotl/monkeypatch/deepspeed_utils.py b/src/axolotl/monkeypatch/deepspeed_utils.py
index 6740f556b..d7e69e112 100644
--- a/src/axolotl/monkeypatch/deepspeed_utils.py
+++ b/src/axolotl/monkeypatch/deepspeed_utils.py
@@ -1,5 +1,6 @@
import importlib
import importlib.util
+
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index f40fe6687..7a2bbd6f9 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -17,8 +17,8 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+ AxolotlInputConfig as AxolotlInputConfigBase,
)
-from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
LOG = get_logger(__name__)
diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py
index 788f13638..8b9e4e91d 100644
--- a/src/axolotl/utils/data/__init__.py
+++ b/src/axolotl/utils/data/__init__.py
@@ -1,14 +1,14 @@
"""Init for `axolotl.utils.data` module."""
-from axolotl.utils.data.streaming import (
- encode_streaming,
- wrap_streaming_dataset,
-)
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
get_dataset_wrapper,
prepare_datasets,
)
+from axolotl.utils.data.streaming import (
+ encode_streaming,
+ wrap_streaming_dataset,
+)
from axolotl.utils.data.utils import md5
__all__ = [
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 28732e01d..ba5aec2d6 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -16,7 +16,6 @@ from transformers import PreTrainedTokenizer, ProcessorMixin
from axolotl.prompters import Prompter
from axolotl.utils.data.lock import FileLockLoader
-from axolotl.utils.data.streaming import wrap_streaming_dataset
from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
@@ -27,6 +26,7 @@ from axolotl.utils.data.shared import (
save_preprocessed_dataset,
try_load_from_hub,
)
+from axolotl.utils.data.streaming import wrap_streaming_dataset
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
handle_long_seq_in_dataset,
diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py
index 751f7e253..192aca4e1 100644
--- a/src/axolotl/utils/environment.py
+++ b/src/axolotl/utils/environment.py
@@ -6,8 +6,6 @@ from importlib.metadata import version
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
-)
-from accelerate.utils.environment import (
get_gpu_info,
)
from packaging.version import Version, parse
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e4c1fdf29..d612ec8a5 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -106,6 +106,12 @@ class AxolotlInputConfig(
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
},
)
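+ # Consumed by ModelLoader._build_model: when true, the model is built
+ # via from_config (random init) instead of from_pretrained, e.g.
+ # `reinit_weights: true` in a config YAML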
+ reinit_weights: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Reinitialize model weights randomly instead of loading pretrained weights"
+ },
+ )
trainer_cls: str | None = Field(
default=None,
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 49add8081..64018ca48 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -14,7 +14,6 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
-
LOG = get_logger(__name__)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
new file mode 100644
index 000000000..cc3d8070b
--- /dev/null
+++ b/tests/e2e/test_diffusion.py
@@ -0,0 +1,139 @@
+"""E2E smoke test for diffusion training plugin."""
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import check_model_output_exists
+
+
+class TestDiffusion:
+ """Test case for diffusion training plugin."""
+
+ def test_diffusion_smoke_test(self, temp_dir):
+ """
+ Smoke test for diffusion training to ensure the plugin loads and trains without
+ error.
+ """
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "trust_remote_code": True,
+ "sequence_len": 256,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.0001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": False,
+ "logging_steps": 1,
+ "eval_steps": 3,
+ # Diffusion-specific config
+ "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
+ "diffusion": {
+ # sample generation
+ "generate_samples": True,
+ "generation_interval": 1,
+ "num_generation_samples": 1,
+ "generation_steps": 2,
+ "generation_max_length": 32,
+ "generation_temperature": 0.0,
+ # training-specific
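+ # arbitrary in-vocab id used as the mask token for the test
+ # (SmolLM2 has no dedicated mask token)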
+ "mask_token_id": 16,
+ "eps": 1e-3,
+ "importance_weighting": False,
+ },
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ check_model_output_exists(temp_dir, cfg)
+
+ def test_diffusion_sft_labels(self, temp_dir):
+ """Test that diffusion training properly handles SFT data with labels."""
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "trust_remote_code": True,
+ "sequence_len": 256,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.0001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": False,
+ "logging_steps": 1,
+ "eval_steps": 2,
+ # Diffusion-specific config
+ "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
+ "diffusion": {
+ # sample generation
+ "generate_samples": True,
+ "generation_interval": 1,
+ "num_generation_samples": 1,
+ "generation_steps": 2,
+ "generation_max_length": 32,
+ "generation_temperature": 0.0,
+ # training-specific
+ "mask_token_id": 16,
+ "eps": 1e-3,
+ "importance_weighting": True,
+ },
+ # Ensure we have proper SFT labels
+ "train_on_inputs": False,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ # Verify that the dataset has labels
+ sample = dataset_meta.train_dataset[0]
+ assert "labels" in sample, "SFT dataset should have labels"
+
+ # Check that some labels are -100 (prompt tokens)
+ labels = sample["labels"]
+ if hasattr(labels, "tolist"):
+ labels = labels.tolist()
+ assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ check_model_output_exists(temp_dir, cfg)
diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py
new file mode 100644
index 000000000..141d8d150
--- /dev/null
+++ b/tests/integrations/test_diffusion.py
@@ -0,0 +1,274 @@
+"""Tests for diffusion trainer integration."""
+
+# pylint: disable=redefined-outer-name,protected-access
+
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from axolotl.integrations.diffusion import DiffusionTrainer
+from axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture
+def mock_tokenizer():
+ """Create a mock tokenizer."""
+ tokenizer = Mock()
+ tokenizer.bos_token_id = 1
+ tokenizer.eos_token_id = 2
+ tokenizer.pad_token_id = 0
+ return tokenizer
+
+
+@pytest.fixture
+def diffusion_config():
+ """Create a diffusion config."""
+ return DictDefault(
+ {
+ "diffusion": {
+ "mask_token_id": 32000,
+ "eps": 1e-3,
+ "importance_weighting": False,
+ },
+ "sample_packing": False,
+ }
+ )
+
+
+@pytest.fixture
+def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
+ """Create a diffusion trainer instance for testing methods directly."""
+ # Create a minimal trainer instance just for testing methods
+ trainer = object.__new__(DiffusionTrainer) # Bypass __init__
+ trainer.cfg = diffusion_config
+ trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
+ trainer.processing_class = mock_tokenizer
+ trainer.store_metrics = Mock() # Mock metrics storage
+ return trainer
+
+
+class TestDiffusionTrainer:
+ """Test the DiffusionTrainer class."""
+
+ def test_forward_process_basic(self, diffusion_trainer_instance):
+ """Test basic forward process without labels."""
+ input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
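+ # bos (1), three ordinary tokens, eos (2), matching the mock
+ # tokenizer's special token ids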
+
+ noisy_batch, masked_indices, p_mask = (
+ diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
+ )
+
+ # Check shapes
+ assert noisy_batch.shape == input_ids.shape
+ assert masked_indices.shape == input_ids.shape
+ assert p_mask.shape == input_ids.shape
+
+ # Check that special tokens are not masked
+ special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
+ assert not masked_indices[special_token_positions].any()
+
+ # Check that mask token is applied
+ mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id
+ if masked_indices.any():
+ assert (noisy_batch[masked_indices] == mask_token_id).all()
+
+ def test_forward_process_with_labels(self, diffusion_trainer_instance):
+ """Test forward process with SFT labels."""
+ input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+ labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
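+ # -100 marks prompt positions ignored by the loss; the remaining ids
+ # are the answer tokens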
+
+ noisy_batch, masked_indices, p_mask = (
+ diffusion_trainer_instance._forward_process(
+ input_ids, labels=labels, eps=0.1
+ )
+ )
+
+ # Check shapes
+ assert noisy_batch.shape == input_ids.shape
+ assert masked_indices.shape == input_ids.shape
+ assert p_mask.shape == input_ids.shape
+
+ # Check that only answer tokens can be masked (where labels != -100)
+ non_answer_mask = labels == -100
+
+ # No masking should occur on non-answer tokens; p_mask is still defined
+ # for every position (the timestep is sampled per sequence), but masking
+ # itself only applies to answer tokens
+ assert not masked_indices[non_answer_mask].any()
+
+ def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
+ """Test forward process with attention mask."""
+ input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
+ attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
+
+ _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
+ input_ids, attention_mask=attention_mask, eps=0.1
+ )
+
+ # Check that padding tokens are not masked
+ padding_positions = attention_mask == 0
+ assert not masked_indices[padding_positions].any()
+ assert (p_mask[padding_positions] == 0).all()
+
+ def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
+ """Test bidirectional attention mask without sample packing."""
+ input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
+
+ mask = create_bidirectional_attention_mask(input_ids)
+
+ # Should be all-to-all attention
+ expected_shape = (1, 1, 4, 4)
+ assert mask.shape == expected_shape
+ assert mask.all()
+
+ def test_bidirectional_attention_mask_with_packing(
+ self, diffusion_trainer_instance
+ ):
+ """Test bidirectional attention mask with sample packing."""
+ diffusion_trainer_instance.cfg.sample_packing = True
+ input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
+ # Sample IDs: first sample (1), second sample (2)
+ attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
+
+ mask = create_bidirectional_attention_mask(
+ input_ids, attention_mask, sample_packing=True
+ )
+
+ # Check that tokens within same sample can attend to each other
+ # but not across samples
+ assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
+ assert mask[0, 0, 1, 2].item()
+ assert not mask[0, 0, 0, 3].item() # Can't attend across samples
+ assert not mask[0, 0, 2, 4].item()
+ assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
+
+ def test_compute_loss_basic(self, diffusion_trainer_instance):
+ """Test basic loss computation."""
+ # Mock model that returns logits
+ mock_model = Mock()
+ mock_outputs = Mock()
+ vocab_size = 1000
+ seq_len = 5
+ mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+ mock_model.return_value = mock_outputs
+ mock_model.training = True
+
+ input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+ loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
+ mock_model, input_ids
+ )
+
+ # Check that loss is computed
+ assert isinstance(loss, torch.Tensor)
+ assert loss.requires_grad
+ assert outputs == mock_outputs
+
+ # Check that metrics were stored
+ diffusion_trainer_instance.store_metrics.assert_called_once()
+
+ def test_compute_loss_sft(self, diffusion_trainer_instance):
+ """Test loss computation with SFT labels."""
+ # Mock model
+ mock_model = Mock()
+ mock_outputs = Mock()
+ vocab_size = 1000
+ seq_len = 5
+ mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+ mock_model.return_value = mock_outputs
+ mock_model.training = True
+ diffusion_trainer_instance.cfg.datasets = Mock()
+
+ input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+ labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
+
+ loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
+ mock_model, input_ids, labels=labels
+ )
+
+ # Check that loss is computed
+ assert isinstance(loss, torch.Tensor)
+ assert loss.requires_grad
+
+ # Check that SFT metrics were added
+ call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
+ assert "answer_ratio" in call_args
+ assert "avg_answer_length" in call_args
+
+ def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
+ """Test loss computation when no tokens are masked."""
+ # Mock model
+ mock_model = Mock()
+ mock_outputs = Mock()
+ vocab_size = 1000
+ seq_len = 3
+ mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
+ mock_model.return_value = mock_outputs
+ mock_model.training = True
+
+ # Only special tokens (which won't be masked)
+ input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
+
+ loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
+ mock_model, input_ids
+ )
+
+ # Loss should be zero when no tokens are masked
+ assert loss.item() == 0.0
+ assert loss.requires_grad
+
+ def test_cache_special_token_ids(self, mock_tokenizer):
+ """Test caching of special token IDs."""
+ trainer = object.__new__(DiffusionTrainer)
+ trainer.processing_class = mock_tokenizer
+ trainer._cache_special_token_ids()
+ assert trainer._special_token_ids == {0, 1, 2}
+
+ def test_cache_special_token_ids_no_tokenizer(self):
+ """Test caching when no tokenizer is available."""
+ trainer = object.__new__(DiffusionTrainer)
+ trainer.processing_class = None
+ trainer._cache_special_token_ids()
+
+ assert trainer._special_token_ids == set()
+
+ def test_main_compute_loss_interface(self, diffusion_trainer_instance):
+ """Test the main compute_loss interface."""
+ # Mock model
+ mock_model = Mock()
+ mock_outputs = Mock()
+ mock_outputs.logits = torch.randn(1, 5, 1000)
+ mock_model.return_value = mock_outputs
+ mock_model.training = True
+
+ inputs = {
+ "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
+ "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
+ "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
+ }
+
+ # Test without return_outputs
+ loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test with return_outputs
+ loss, outputs = diffusion_trainer_instance.compute_loss(
+ mock_model, inputs, return_outputs=True
+ )
+ assert isinstance(loss, torch.Tensor)
+ assert outputs == mock_outputs
+
+ def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
+ """Test that missing input_ids raises ValueError."""
+ mock_model = Mock()
+ inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
+
+ with pytest.raises(ValueError, match="input_ids is required"):
+ diffusion_trainer_instance.compute_loss(mock_model, inputs)
diff --git a/tests/integrations/test_diffusion_callback.py b/tests/integrations/test_diffusion_callback.py
new file mode 100644
index 000000000..3e8785fe0
--- /dev/null
+++ b/tests/integrations/test_diffusion_callback.py
@@ -0,0 +1,92 @@
+"""Tests for diffusion generation callback dataloader selection and triggering."""
+
+from types import SimpleNamespace
+from unittest.mock import Mock
+
+import pytest
+
+from axolotl.integrations.diffusion import DiffusionGenerationCallback
+
+
+class DummyTrainer:
+ """Minimal trainer double with required attributes/methods for the callback."""
+
+ def __init__(self, use_eval: bool):
+ # Config used by callback
+ self.cfg = SimpleNamespace(
+ diffusion=SimpleNamespace(
+ generation_interval=1,
+ num_generation_samples=1,
+ generation_max_length=32,
+ generation_steps=4,
+ generation_temperature=0.0,
+ mask_token_id=16,
+ ),
+ use_wandb=False,
+ )
+
+ # Model/tokenizer are passed through to generate_samples; not used here
+ self.model = Mock()
+ self.processing_class = Mock()
+
+ # Datasets and loaders
+ self.eval_dataset = object() if use_eval else None
+ self._train_loader = object()
+ self._eval_loader = object()
+
+ # State for world process check
+ self.state = SimpleNamespace(is_world_process_zero=True)
+
+ # Track which loader was requested
+ self.requested: list[str] = []
+
+ def get_train_dataloader(self):
+ self.requested.append("train")
+ return self._train_loader
+
+ def get_eval_dataloader(self):
+ self.requested.append("eval")
+ return self._eval_loader
+
+
+@pytest.mark.parametrize("use_eval", [False, True])
+def test_callback_uses_correct_dataloader(monkeypatch, use_eval):
+ trainer = DummyTrainer(use_eval=use_eval)
+ callback = DiffusionGenerationCallback(trainer)
+
+ captured = {}
+
+ # Patch generate_samples in the callback module's namespace
+ def fake_generate_samples(**kwargs):
+ captured["dataloader"] = kwargs.get("dataloader")
+ # Return one dummy sample to exercise the logging path
+ return [
+ {
+ "original": "o",
+ "masked": "m",
+ "generated": "g",
+ "mask_ratio": 0.5,
+ "masked_tokens": 1,
+ "total_tokens": 2,
+ }
+ ]
+
+ monkeypatch.setattr(
+ "axolotl.integrations.diffusion.callbacks.generate_samples",
+ fake_generate_samples,
+ )
+
+ # Trigger at step 1 (interval=1)
+ args = SimpleNamespace()
+ state = SimpleNamespace(global_step=1)
+ control = SimpleNamespace()
+
+ callback.on_step_end(args=args, state=state, control=control)
+
+ # Assert the expected dataloader path was used
+ if use_eval:
+ assert trainer.requested[0] == "eval"
+ assert captured["dataloader"] is trainer._eval_loader
+ else:
+ assert trainer.requested[0] == "train"
+ assert captured["dataloader"] is trainer._train_loader
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 54acbb5e4..2c1f9f936 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -5,12 +5,12 @@ from unittest.mock import Mock, patch
from datasets import IterableDataset
-from axolotl.utils.dict import DictDefault
+from axolotl.utils.config import validate_config
from axolotl.utils.data.sft import (
_prepare_streaming_dataset,
prepare_datasets,
)
-from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
class TestStreamingConfig(unittest.TestCase):
From 9406c0c488277ef9d7152568b9fda50600c4221e Mon Sep 17 00:00:00 2001
From: salman
+ A Free and Open Source LLM Fine-tuning Framework
+
@@ -50,20 +53,21 @@
## ✨ Overview
-Axolotl is a tool designed to streamline post-training for various AI models.
+Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Features:
-- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
-- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
-- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
+- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more available on the Hugging Face Hub.
+- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
+- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
+- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
-## 🚀 Quick Start
+## 🚀 Quick Start - LLM Fine-tuning in Minutes
**Requirements**:
@@ -160,7 +164,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
- title = {Axolotl: Post-Training for AI Models},
+ title = {Axolotl: Open Source LLM Post-Training},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
From 58d67bf98ddca63cb082374a04f8b2250ffc2c59 Mon Sep 17 00:00:00 2001
From: salman