From 1b53c49e1a8408ff209ae72a480681d18f7f8c81 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 10 Sep 2025 20:27:00 -0400 Subject: [PATCH] text diffusion training plugin (#3067) * diffusion training plugin * cleanup * nits * fixes + improvements * add back in reinit_weights (clobbered?); masking / pretrain fixes * nits * cleanup; tests draft * sample generation, tests fixes * fixes * nits * add inference support; add auto-mask token support * nits * nits * progress * simplify logging * lint * prefix args with diffusion_ * coderabbito * tests fix * nit * nits * cleanup + nits * nits * fix SFT sample gen * fixes * fix * comments * comments * lint * reward model lora fix * cleanup; fix pretraining_dataset case * gradio inference * update cfgs * update cfgs * train, generation parity, cleanup * fix * simplify * test * test fix --- .pre-commit-config.yaml | 2 +- .../colab-axolotl-example.ipynb | 2 +- examples/llama-3/diffusion/pretrain-1b.yaml | 56 +++ examples/llama-3/diffusion/sft-1b.yaml | 59 +++ src/axolotl/cli/inference.py | 63 ++- src/axolotl/cli/utils/diffusion.py | 375 ++++++++++++++++ src/axolotl/core/builders/causal.py | 15 +- src/axolotl/core/trainers/base.py | 46 +- src/axolotl/integrations/base.py | 2 +- src/axolotl/integrations/config.py | 2 +- src/axolotl/integrations/diffusion/README.md | 154 +++++++ .../integrations/diffusion/__init__.py | 19 + src/axolotl/integrations/diffusion/args.py | 95 ++++ .../integrations/diffusion/callbacks.py | 174 ++++++++ .../integrations/diffusion/generation.py | 409 ++++++++++++++++++ src/axolotl/integrations/diffusion/plugin.py | 41 ++ src/axolotl/integrations/diffusion/trainer.py | 301 +++++++++++++ src/axolotl/integrations/diffusion/utils.py | 159 +++++++ src/axolotl/loaders/adapter.py | 12 +- src/axolotl/loaders/model.py | 118 +++-- src/axolotl/loaders/patch_manager.py | 5 +- src/axolotl/monkeypatch/accelerate/fsdp2.py | 8 +- .../monkeypatch/attention/flex_attn.py | 3 +- src/axolotl/monkeypatch/deepspeed_utils.py | 1 + src/axolotl/utils/config/__init__.py | 2 +- src/axolotl/utils/data/__init__.py | 8 +- src/axolotl/utils/data/sft.py | 2 +- src/axolotl/utils/environment.py | 2 - src/axolotl/utils/schemas/config.py | 6 + src/axolotl/utils/schemas/validation.py | 1 - tests/e2e/test_diffusion.py | 139 ++++++ tests/integrations/test_diffusion.py | 274 ++++++++++++ tests/integrations/test_diffusion_callback.py | 92 ++++ tests/test_streaming.py | 4 +- 34 files changed, 2550 insertions(+), 101 deletions(-) create mode 100644 examples/llama-3/diffusion/pretrain-1b.yaml create mode 100644 examples/llama-3/diffusion/sft-1b.yaml create mode 100644 src/axolotl/cli/utils/diffusion.py create mode 100644 src/axolotl/integrations/diffusion/README.md create mode 100644 src/axolotl/integrations/diffusion/__init__.py create mode 100644 src/axolotl/integrations/diffusion/args.py create mode 100644 src/axolotl/integrations/diffusion/callbacks.py create mode 100644 src/axolotl/integrations/diffusion/generation.py create mode 100644 src/axolotl/integrations/diffusion/plugin.py create mode 100644 src/axolotl/integrations/diffusion/trainer.py create mode 100644 src/axolotl/integrations/diffusion/utils.py create mode 100644 tests/e2e/test_diffusion.py create mode 100644 tests/integrations/test_diffusion.py create mode 100644 tests/integrations/test_diffusion_callback.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92ddc7f41..9c80898ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: rev: v0.12.12 hooks: - id: ruff - args: [--fix] + args: [--fix, --select, I] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.17.1 diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index b780a1c48..0e6ba984e 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -176,8 +176,8 @@ } ], "source": [ - "from axolotl.utils.dict import DictDefault\n", "from axolotl.cli.config import load_cfg\n", + "from axolotl.utils.dict import DictDefault\n", "\n", "# Axolotl provides full control and transparency over model and training configuration\n", "config = DictDefault(\n", diff --git a/examples/llama-3/diffusion/pretrain-1b.yaml b/examples/llama-3/diffusion/pretrain-1b.yaml new file mode 100644 index 000000000..8d05e4c60 --- /dev/null +++ b/examples/llama-3/diffusion/pretrain-1b.yaml @@ -0,0 +1,56 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +pretraining_dataset: + - path: wikitext + name: wikitext-103-raw-v1 + type: completion + field: text + +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin + +diffusion: + noise_schedule: cosine + min_mask_ratio: 0.15 + max_mask_ratio: 0.85 + num_diffusion_steps: 128 + eps: 5e-4 + importance_weighting: true + mask_token_id: 128002 + generate_samples: true + generation_interval: 250 + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true + +gradient_accumulation_steps: 8 +micro_batch_size: 4 +max_steps: 10000 +warmup_ratio: 0.1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 3e-4 +sdp_attention: true + +bf16: auto +tf32: true + +logging_steps: 1 +save_strategy: steps +save_steps: 1000 + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/diffusion/sft-1b.yaml b/examples/llama-3/diffusion/sft-1b.yaml new file mode 100644 index 000000000..f3b29a809 --- /dev/null +++ b/examples/llama-3/diffusion/sft-1b.yaml @@ -0,0 +1,59 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +val_set_size: 0.05 + +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin + +diffusion: + noise_schedule: cosine + min_mask_ratio: 0.1 + max_mask_ratio: 0.9 + num_diffusion_steps: 128 + eps: 1e-3 + importance_weighting: true + mask_token_id: 128002 + generate_samples: true + generation_interval: 250 + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true +eval_sample_packing: true + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 1 +warmup_steps: 0.1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 1e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +resume_from_checkpoint: +sdp_attention: true + +logging_steps: 1 +save_strategy: best +eval_strategy: epoch + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index debe57167..30d407713 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -14,6 +14,13 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from axolotl.cli.args import InferenceCliArgs from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.cli.utils.diffusion import ( + diffusion_inference, + launch_diffusion_gradio_ui, + render_html, + run_diffusion, +) +from axolotl.integrations.base import PluginManager from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -29,6 +36,7 @@ def get_multi_line_input() -> str: Possibly multi-line, possibly empty stdin input as a string. """ print("Give me an instruction (Ctrl + D to submit): ") + print("=" * 80) instruction = "" for line in sys.stdin: @@ -43,9 +51,9 @@ def do_inference( cli_args: InferenceCliArgs, ): """ - Runs inference on the command line in a loop. User input is accepted, a chat template - is (optionally) applied, and the model specified in the `axolotl` config is used to - generate completions according to a default generation config. + Runs inference on the command line in a loop. User input is accepted, a chat + template is (optionally) applied, and the model specified in the `axolotl` config is + used to generate completions according to a default generation config. Args: cfg: Dictionary mapping `axolotl` config keys to values. @@ -64,16 +72,28 @@ def do_inference( chat_template_str = get_chat_template_from_config( cfg, ds_cfg=None, tokenizer=tokenizer ) - elif cfg.datasets[0].type == "chat_template": + elif cfg.datasets and cfg.datasets[0].type == "chat_template": chat_template_str = get_chat_template_from_config( cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer ) model = model.to(cfg.device, dtype=cfg.torch_dtype) + # Detect diffusion mode + plugin_manager = PluginManager.get_instance() + is_diffusion = any( + plugin.__class__.__name__ == "DiffusionPlugin" + for plugin in plugin_manager.plugins.values() + ) + + if is_diffusion: + print("=" * 80) + print("Commands:") + print(":complete N -> completion mode with N tokens (default 64)") + print(":mask R -> random masking with ratio R (0.0–1.0)") + while True: print("=" * 80) - # support for multiline inputs instruction = get_multi_line_input() if not instruction: return @@ -103,9 +123,19 @@ def do_inference( else: batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - print("=" * 40) + print("=" * 80) model.eval() with torch.no_grad(): + if is_diffusion: + diffusion_inference( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + ) + continue + generation_config = GenerationConfig( repetition_penalty=1.1, max_new_tokens=1024, @@ -128,7 +158,7 @@ def do_inference( generation_config=generation_config, streamer=streamer, ) - print("=" * 40) + print("=" * 80) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) @@ -161,13 +191,30 @@ def do_inference_gradio( chat_template_str = get_chat_template_from_config( cfg, ds_cfg=None, tokenizer=tokenizer ) - elif cfg.datasets[0].type == "chat_template": + elif cfg.datasets and cfg.datasets[0].type == "chat_template": chat_template_str = get_chat_template_from_config( cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer ) model = model.to(cfg.device, dtype=cfg.torch_dtype) + # Detect diffusion mode + plugin_manager = PluginManager.get_instance() + is_diffusion = any( + plugin.__class__.__name__ == "DiffusionPlugin" + for plugin in plugin_manager.plugins.values() + ) + + if is_diffusion: + launch_diffusion_gradio_ui( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompter_module=prompter_module, + chat_template_str=chat_template_str, + ) + return + def generate(instruction): if not instruction: return diff --git a/src/axolotl/cli/utils/diffusion.py b/src/axolotl/cli/utils/diffusion.py new file mode 100644 index 000000000..f83d9077b --- /dev/null +++ b/src/axolotl/cli/utils/diffusion.py @@ -0,0 +1,375 @@ +"""Helpers for diffusion-mode inference in CLI and Gradio.""" + +from __future__ import annotations + +import gradio as gr +import torch +from colorama import Fore, Style + +from axolotl.integrations.diffusion import generate, resolve_mask_token_id +from axolotl.utils.dict import DictDefault + + +def diffusion_inference( + model, + tokenizer, + cfg, + prompt: str, + chat_template_str: str | None = None, +): + """Diffusion inference helper method.""" + mode = "random" + completion_tokens = 0 + target_mask_ratio = None + mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt) + + if cleaned: + prompt = cleaned + + info = run_diffusion( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + mode=mode, + target_mask_ratio=target_mask_ratio, + completion_tokens=completion_tokens, + ) + masked_text = info["masked_text"] + mask_ratio = info["mask_ratio"] + generated_ids = info["generated_ids"] + masked_positions = info["masked_positions"] + orig_ids = info["orig_ids"] + + # Display with masked preview and colored diff + if masked_text is not None and mask_ratio is not None: + print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n") + if generated_ids is not None: + # Compute per-token style + styles: list[str] = [] + for i, tid in enumerate(generated_ids): + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + styles.append("green") # correct fill + elif i < len(orig_ids): + styles.append("red") # incorrect fill + else: + styles.append("normal") # appended + else: + same = i < len(orig_ids) and tid == orig_ids[i] + styles.append("dim" if same else "normal") + + # Group contiguous spans by style + styled_spans: list[tuple[str, int, int]] = [] + if generated_ids: + current_style = styles[0] + start = 0 + for i in range(1, len(generated_ids)): + s = styles[i] + if s != current_style: + styled_spans.append((current_style, start, i)) + current_style, start = s, i + styled_spans.append((current_style, start, len(generated_ids))) + + out_parts = [] + for style_name, a, b in styled_spans: + chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False) + if style_name == "green": + out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL) + elif style_name == "red": + out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL) + else: + if style_name == "dim": + out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL) + else: + out_parts.append(chunk_text) + print("Generated:\n" + "".join(out_parts)) + else: + print("Generated:\n(no output)") + + +def _parse_commands(text: str): + """ + Parse leading diffusion commands. + + Supported at start of input (can be chained): + :complete N -> completion mode with N tokens (default 64) + :mask R -> random masking with ratio R in [0, 1] + """ + tokens = text.strip().split() + i = 0 + mode = "random" + completion_tokens = 0 + target_mask_ratio = None + consumed = 0 + while i < len(tokens) and tokens[i].startswith(":"): + cmd = tokens[i] + i += 1 + consumed = i + if cmd == ":complete": + mode = "completion" + if i < len(tokens): + try: + completion_tokens = int(tokens[i]) + i += 1 + consumed = i + except Exception: + completion_tokens = 64 + else: + completion_tokens = 64 + elif cmd == ":mask": + mode = "random" + if i < len(tokens): + try: + target_mask_ratio = float(tokens[i]) + i += 1 + consumed = i + except Exception: + target_mask_ratio = None + else: + i -= 1 + consumed = i + break + + cleaned = " ".join(tokens[consumed:]) + + return mode, completion_tokens, target_mask_ratio, cleaned + + +def run_diffusion( + *, + model, + tokenizer, + cfg: DictDefault, + prompt: str, + chat_template_str: str | None, + mode: str = "random", + target_mask_ratio: float | None = None, + completion_tokens: int = 0, +): + """Run a single diffusion generation and return a structured result dict.""" + if chat_template_str: + batch = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False) + + seq = batch["input_ids"].to(cfg.device) + gen_mode = "completion" if mode == "completion" else "random" + comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0 + + result = generate( + model, + tokenizer, + original_sequence=seq[:1], + num_diffusion_steps=cfg.diffusion.num_diffusion_steps, + temperature=cfg.diffusion.generation_temperature, + mask_token_id=int(mask_token_id), + mode=gen_mode, # type: ignore[arg-type] + completion_tokens=comp_tokens, + target_mask_ratio=target_mask_ratio, + ) + + masked_text = result.get("masked") if isinstance(result, dict) else None + mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None + generated_ids = result.get("generated_ids") if isinstance(result, dict) else None + masked_positions = ( + set(result.get("masked_positions") or []) if isinstance(result, dict) else set() + ) + orig_ids = seq[0].detach().cpu().tolist() + + return { + "masked_text": masked_text, + "mask_ratio": mask_ratio, + "generated_ids": generated_ids, + "masked_positions": masked_positions, + "orig_ids": orig_ids, + } + + +def render_html( + *, + generated_ids: list[int] | None, + orig_ids: list[int], + masked_positions: set[int], + tokenizer, +) -> str: + """Render HTML visualizing diffusion outputs.""" + if not generated_ids: + return "
Generated:\n(no output)
" + + def _style_for(i: int, tid: int) -> str: + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + return "green" + if i < len(orig_ids): + return "red" + return "normal" + same = i < len(orig_ids) and tid == orig_ids[i] + return "dim" if same else "normal" + + # Group contiguous spans by style to reduce HTML size + spans: list[tuple[str, int, int]] = [] + if generated_ids: + cur = _style_for(0, generated_ids[0]) + start = 0 + for i in range(1, len(generated_ids)): + s = _style_for(i, generated_ids[i]) + if s != cur: + spans.append((cur, start, i)) + cur, start = s, i + spans.append((cur, start, len(generated_ids))) + + html_parts = [] + for style_name, a, b in spans: + txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False) + if style_name == "green": + html_parts.append(f'{txt}') + elif style_name == "red": + html_parts.append(f'{txt}') + elif style_name == "dim": + html_parts.append(f'{txt}') + else: + html_parts.append(txt) + + legend = ( + '
' + 'correct, ' + 'incorrect, ' + 'unchanged' + "
" + ) + + return ( + legend + + '
Generated:\n'
+        + "".join(html_parts)
+        + "
" + ) + + +def launch_diffusion_gradio_ui( + *, + model, + tokenizer, + cfg: DictDefault, + prompter_module=None, + chat_template_str: str | None = None, +): + """Build and launch a simple Gradio UI for diffusion inference.""" + with gr.Blocks( + title=cfg.get("gradio_title", "Axolotl Diffusion Interface") + ) as demo: + gr.Markdown( + """ + ## Axolotl Diffusion Inference + - Mode "Random" masks tokens at a target ratio and fills them. + - Mode "Completion" appends N masked tokens at the end and fills them. + """ + ) + + with gr.Row(): + mode = gr.Radio( + choices=["random", "completion"], + value="random", + label="Mode", + ) + mask_ratio = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.05, + value=0.4, + label="Mask ratio (random mode)", + interactive=True, + ) + completion_tokens = gr.Number( + value=64, + precision=0, + label="Completion tokens (completion mode)", + interactive=True, + visible=False, + ) + + instruction = gr.Textbox(label="Instruction", lines=6) + run_btn = gr.Button("Generate") + + masked_preview = gr.Textbox(label="Masked preview", lines=6) + html_out = gr.HTML(label="Generated") + + def _toggle_controls(selected_mode: str): + return ( + gr.update(visible=(selected_mode == "random")), + gr.update(visible=(selected_mode == "completion")), + ) + + mode.change( + _toggle_controls, + inputs=[mode], + outputs=[mask_ratio, completion_tokens], + ) + + def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int): + if not instruction_text: + return "", "
Generated:\n(no output)
" + + if prompter_module: + prompt: str = next( + prompter_module().build_prompt( + instruction=instruction_text.strip("\n") + ) + ) + else: + prompt = instruction_text.strip() + + info = run_diffusion( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + mode=selected_mode, + target_mask_ratio=mratio if selected_mode == "random" else None, + completion_tokens=int(ctoks) if selected_mode == "completion" else 0, + ) + + masked_text = info.get("masked_text") + mask_ratio_val = info.get("mask_ratio") + generated_ids = info.get("generated_ids") + masked_positions = info.get("masked_positions") or set() + orig_ids = info.get("orig_ids") or [] + + preview = ( + f"Masked ({mask_ratio_val:.1%}):\n{masked_text}" + if masked_text is not None and mask_ratio_val is not None + else "" + ) + html = render_html( + generated_ids=generated_ids, + orig_ids=orig_ids, + masked_positions=masked_positions, + tokenizer=tokenizer, + ) + return preview, html + + run_btn.click( + _gen, + inputs=[instruction, mode, mask_ratio, completion_tokens], + outputs=[masked_preview, html_out], + ) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index ee6383d47..f7f350e1a 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -7,7 +7,11 @@ from pathlib import Path from typing import Type, Union import transformers -from transformers import DataCollatorWithFlattening, EarlyStoppingCallback +from transformers import ( + DataCollatorWithFlattening, + EarlyStoppingCallback, + Trainer, +) from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.builders.base import TrainerBuilderBase @@ -23,15 +27,16 @@ from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( + LossWatchDogCallback, + SaveBetterTransformerModelCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, colab_inference_post_train_callback, log_prediction_callback_factory, - LossWatchDogCallback, - SaveBetterTransformerModelCallback, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.qat import QATCallback +from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -39,7 +44,6 @@ from axolotl.utils.collators import ( MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) -from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger @@ -391,10 +395,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): **data_collator_kwargs, ) sig = inspect.signature(trainer_cls) - if "processing_class" in sig.parameters: + if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer): trainer_kwargs["processing_class"] = self.tokenizer elif "tokenizer" in sig.parameters: trainer_kwargs["tokenizer"] = self.tokenizer + if ( trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer] and self.cfg.datasets is not None diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d7555261f..3427a0b86 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -49,6 +49,13 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__) +REDUCTION_FNS = { + "mean": torch.mean, + "min": torch.min, + "max": torch.max, + "sum": torch.sum, +} + class AxolotlTrainer( PackingMixin, @@ -89,7 +96,9 @@ class AxolotlTrainer( super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) + self._stored_metrics = defaultdict( + lambda: defaultdict(lambda: {"values": [], "reduction": "mean"}) + ) if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -585,9 +594,17 @@ class AxolotlTrainer( """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() + + for key, metric_data in self._stored_metrics[train_eval].items(): + values = torch.tensor(metric_data["values"]) # type: ignore[arg-type] + reduction_type = metric_data["reduction"] + + fn = REDUCTION_FNS.get(reduction_type) + if fn is None: + raise NotImplementedError( + "Metric reduction must be one of [mean, min, max, sum]" + ) + logs[key] = round(fn(values).item(), 4) if is_main_process(): # Add memory usage @@ -611,10 +628,27 @@ class AxolotlTrainer( return super().log(logs, start_time) def store_metrics( - self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train" + self, + metrics: dict[str, float] | dict[str, tuple[int | float, str]], + train_eval: Literal["train", "eval"] = "train", + reduction: Literal["mean", "min", "max", "sum"] = "mean", ) -> None: + """ + Store metrics with specified reduction type. + + Args: + metrics: Dictionary of metric names to values, or metric names to (value, + reduction_type) tuples. + train_eval: Whether this is for training or evaluation. + """ for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) + if isinstance(value, tuple): + value, _reduction = value # type: ignore[assignment] + else: + value, _reduction = value, reduction + + self._stored_metrics[train_eval][key]["values"].append(value) + self._stored_metrics[train_eval][key]["reduction"] = _reduction def _save_checkpoint(self, model, trial, **kwargs): # make sure the checkpoint dir exists, since trainer is flakey diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 8edee18a3..c66bc01c6 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -142,7 +142,7 @@ class BasePlugin: model: The loaded model. """ - def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: """Returns a custom class for the trainer. Args: diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index 2217b2819..8ae8aab39 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, + AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase def merge_input_args(): diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md new file mode 100644 index 000000000..c27f33de1 --- /dev/null +++ b/src/axolotl/integrations/diffusion/README.md @@ -0,0 +1,154 @@ +# Diffusion LM Training Plugin for Axolotl + +This plugin enables diffusion language model training using an approach inspired by +LLaDA (Large Language Diffusion Models) within Axolotl. + +## Overview + +LLaDA is a diffusion-based approach to language model training that uses: +- **Random token masking** during training instead of next-token prediction +- **Bidirectional attention** to allow the model to attend to the full context +- **Importance weighting** based on masking probabilities for stable training + +This approach can lead to more robust language models with better understanding of +bidirectional context. + +## Installation + +The plugin is included with Axolotl. See our +[installation docs](https://docs.axolotl.ai/docs/installation.html). + +## Quickstart + +Train with an example config (Llama‑3.2 1B): + - Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml` + - SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml` + +### Basic Configuration + +You can also modify your existing configs to enable / customize diffusion training. + +Add the following to your Axolotl config: + +```yaml +# Enable diffusion LM training plugin +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin +``` + +And, configure the nested `diffusion` block (defaults shown): + +```yaml +diffusion: + noise_schedule: linear # or "cosine" + min_mask_ratio: 0.1 + max_mask_ratio: 0.9 + num_diffusion_steps: 128 + eps: 1e-3 + importance_weighting: true + + # Mask token (training auto-adds if missing, avoid pad/eos) + mask_token_str: "<|diffusion_mask|>" + # Or use an existing special token id (e.g., 128002 for Llama-3.x) + # mask_token_id: 128002 + + # Sample generation during training (optional) + generate_samples: true + generation_interval: 100 + num_generation_samples: 3 + generation_steps: 128 + generation_temperature: 0.0 + generation_max_length: 100 +``` + +## Supported Models + +Any models that support 4D attention masks should work out of the box. If not, please +create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a +[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)! + +## How It Works + +### Random Masking +During training, tokens are randomly masked: +- Sample timestep `t` uniformly from [0, 1] +- Calculate masking probability: `p = (1 - eps) * t + eps` +- Randomly mask tokens with probability `p` + +### Diffusion Loss + +Loss is computed only on masked tokens with (optional) importance weighting: + +```python +loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens +``` + +## Sample Generation + +When `diffusion.generate_samples: true`, the plugin generates samples during training: + +``` +Sample 1: + Original (45 tokens): The quick brown fox jumps over the lazy dog... + Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]... + Generated: The quick brown fox jumps over the lazy dog... +``` + +Samples are logged to console and wandb (if enabled). + +## Inference + +Diffusion inference is integrated into the standard Axolotl CLI. Use the same config +you trained with and run: + +``` +axolotl inference path/to/your-config.yaml +``` + +Optionally, pass `--gradio` to use a simple web interface. + +Interactive controls (prefix the prompt with commands): +- `:complete N` → completion mode with N new masked tokens appended (default 64) +- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0] + +Example session: + +``` +================================================================================ +Commands: +:complete N -> completion mode with N tokens (default 64) +:mask R -> random masking with ratio R (0.0–1.0) +================================================================================ +Give me an instruction (Ctrl + D to submit): + +:mask 0.4 The quick brown fox jumps over the lazy dog + +Masked (40.0%): +The [MASK] brown [MASK] jumps over the [MASK] dog + +Generated: +The quick brown fox jumps over the loud dog +``` + +## Metrics and Monitoring + +The plugin adds (or modifies) several metrics to track diffusion training: + +- `train/loss`: Weighted diffusion loss +- `train/accuracy`: Accuracy on masked tokens +- `train/mask_ratio`: Average fraction of tokens masked +- `train/num_masked_tokens`: Number of tokens masked +- `train/avg_p_mask`: Average masking probability +- `train/ce_loss`: Unweighted cross-entropy loss +- `train/importance_weight_avg`: Average importance weight + +## Limitations + +- No flash attention support +- No RL training support + +## References + +- [LLaDA Paper](https://arxiv.org/abs/2404.10406) +- [Axolotl Documentation](https://docs.axolotl.ai/) +- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args) diff --git a/src/axolotl/integrations/diffusion/__init__.py b/src/axolotl/integrations/diffusion/__init__.py new file mode 100644 index 000000000..9e38cc5c1 --- /dev/null +++ b/src/axolotl/integrations/diffusion/__init__.py @@ -0,0 +1,19 @@ +"""Diffusion LM training plugin init.""" + +from .args import DiffusionArgs, DiffusionConfig +from .callbacks import DiffusionGenerationCallback +from .generation import generate +from .plugin import DiffusionPlugin +from .trainer import DiffusionTrainer +from .utils import create_bidirectional_attention_mask, resolve_mask_token_id + +__all__ = [ + "DiffusionArgs", + "DiffusionPlugin", + "DiffusionTrainer", + "generate", + "resolve_mask_token_id", + "create_bidirectional_attention_mask", + "DiffusionGenerationCallback", + "DiffusionConfig", +] diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py new file mode 100644 index 000000000..4f5bfe499 --- /dev/null +++ b/src/axolotl/integrations/diffusion/args.py @@ -0,0 +1,95 @@ +"""Config args for diffusion LM training (nested under `diffusion:`).""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field, model_validator + + +class DiffusionConfig(BaseModel): + """Nested diffusion configuration available under the `diffusion` key.""" + + # Noise schedule config + noise_schedule: Literal["linear", "cosine"] = Field( + default="linear", description="Type of noise schedule for diffusion training" + ) + min_mask_ratio: float = Field( + default=0.1, + ge=0.0, + le=1.0, + description="Minimum masking ratio for diffusion noise schedule", + ) + max_mask_ratio: float = Field( + default=0.9, + ge=0.0, + le=1.0, + description="Maximum masking ratio for diffusion noise schedule", + ) + num_diffusion_steps: int = Field( + default=128, ge=1, description="Number of diffusion timesteps" + ) + eps: float = Field( + default=1e-3, + ge=0.0, + le=1.0, + description="Epsilon value for minimum masking probability in forward process", + ) + + # Training config + importance_weighting: bool = Field( + default=True, + description="Apply importance weighting to loss based on masking probability", + ) + mask_token_id: int | None = Field( + default=None, + description=( + "Token ID to use for masking. Unset by default; can use one of the " + "tokenizer's special tokens here." + ), + ) + mask_token_str: str | None = Field( + default=None, + description=( + "Token string to use as a mask. If `mask_token_id` is invalid or unset, " + "this token will be ensured to exist as an additional special token and " + "used. If absent, a default '<|diffusion_mask|>' will be added." + ), + ) + + # Sample generation config + generate_samples: bool = Field( + default=True, description="Enable sample generation during training" + ) + generation_interval: int = Field( + default=100, ge=1, description="Generate samples every N steps" + ) + num_generation_samples: int = Field( + default=3, ge=1, description="Number of samples to generate each time" + ) + generation_steps: int = Field( + default=128, ge=1, description="Number of diffusion steps for generation" + ) + generation_temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for generation sampling (0.0 = deterministic)", + ) + generation_max_length: int = Field( + default=100, ge=1, description="Maximum sequence length for generation" + ) + + @model_validator(mode="after") + def _validate_mask_ratios(self) -> "DiffusionConfig": + if self.min_mask_ratio > self.max_mask_ratio: + raise ValueError("min_mask_ratio must be ≤ max_mask_ratio") + return self + + +class DiffusionArgs(BaseModel): + """Plugin entry that exposes the nested `diffusion` block to the core config.""" + + diffusion: DiffusionConfig = Field( + default_factory=DiffusionConfig, + description="Diffusion training configuration. Only nested block is supported.", + ) diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py new file mode 100644 index 000000000..18a64023b --- /dev/null +++ b/src/axolotl/integrations/diffusion/callbacks.py @@ -0,0 +1,174 @@ +"""Callbacks for diffusion training.""" + +import logging +import sys + +import wandb +from colorama import Fore, Style +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from .generation import generate_samples + +# Simpler logger for more readable sample generation +logger = logging.getLogger(__name__) +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(handler) + logger.propagate = False +logger.setLevel(logging.INFO) + + +class DiffusionGenerationCallback(TrainerCallback): + """Callback for generating samples during diffusion training.""" + + def __init__(self, trainer): + self.trainer = trainer + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Generate samples at specified intervals.""" + if ( + state.global_step > 0 + and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0 + ): + if not self.trainer.state.is_world_process_zero: + return + + # Use eval dataloader if available, otherwise use train dataloader + dataloader = None + try: + if getattr(self.trainer, "eval_dataset", None) is not None: + dataloader = self.trainer.get_eval_dataloader() + except Exception: + dataloader = None + if dataloader is None: + dataloader = self.trainer.get_train_dataloader() + + # Generate samples + diffusion_cfg = self.trainer.cfg.diffusion + samples = generate_samples( + model=self.trainer.model, + tokenizer=self.trainer.processing_class, + dataloader=dataloader, + num_generation_samples=diffusion_cfg.num_generation_samples, + max_length=diffusion_cfg.generation_max_length, + num_diffusion_steps=diffusion_cfg.generation_steps, + temperature=diffusion_cfg.generation_temperature, + mask_token_id=diffusion_cfg.mask_token_id, + ) + + # Log samples + self._log_samples(samples, state.global_step) + + def _log_samples(self, samples: list, step: int): + """Log generated samples.""" + if not samples: + return + + logger.info("=" * 60) + logger.info("GENERATED SAMPLES") + logger.info("=" * 60) + + for i, sample_data in enumerate(samples, 1): + original = sample_data["original"] + masked = sample_data["masked"] + generated = sample_data["generated"] + mask_ratio = sample_data["mask_ratio"] + masked_tokens = sample_data["masked_tokens"] + total_tokens = sample_data["total_tokens"] + + logger.info(f"\nSample {i}:") + logger.info(f"\tOriginal ({total_tokens} tokens): {original}") + logger.info( + f"\tMasked ({masked_tokens}/{total_tokens} tokens, " + f"{mask_ratio:.1%}): {masked}" + ) + + try: + gen_ids = sample_data.get("generated_ids") + orig_ids = sample_data.get("orig_ids") + masked_positions = set(sample_data.get("masked_positions") or []) + if isinstance(gen_ids, list) and isinstance(orig_ids, list): + styles: list[str] = [] + for i, tid in enumerate(gen_ids): + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + styles.append("green") + elif i < len(orig_ids): + styles.append("red") + else: + styles.append("normal") + else: + same = i < len(orig_ids) and tid == orig_ids[i] + styles.append("dim" if same else "normal") + + spans: list[tuple[str, int, int]] = [] + if gen_ids: + cur = styles[0] + start = 0 + for i in range(1, len(gen_ids)): + s = styles[i] + if s != cur: + spans.append((cur, start, i)) + cur, start = s, i + spans.append((cur, start, len(gen_ids))) + + parts = [] + for style_name, a, b in spans: + chunk_text = self.trainer.processing_class.decode( + gen_ids[a:b], skip_special_tokens=False + ) + if style_name == "green": + parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL) + elif style_name == "red": + parts.append(Fore.RED + chunk_text + Style.RESET_ALL) + else: + if style_name == "dim": + parts.append(Style.DIM + chunk_text + Style.RESET_ALL) + else: + parts.append(chunk_text) + logger.info("\tGenerated:\n%s", "".join(parts)) + else: + logger.info(f"\tGenerated: {generated}") + except Exception: + logger.info(f"\tGenerated: {generated}") + + logger.info("=" * 60) + + if self.trainer.cfg.use_wandb: + if wandb.run is not None: + wandb.log( + { + "generated_samples": wandb.Table( + columns=[ + "step", + "original", + "masked", + "generated", + "mask_ratio", + "masked_tokens", + "total_tokens", + ], + data=[ + [ + step, + sample["original"], + sample["masked"], + sample["generated"], + f"{sample['mask_ratio']:.1%}", + sample["masked_tokens"], + sample["total_tokens"], + ] + for sample in samples + ], + ) + }, + step=step, + ) diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py new file mode 100644 index 000000000..49e3cdfae --- /dev/null +++ b/src/axolotl/integrations/diffusion/generation.py @@ -0,0 +1,409 @@ +"""Sample generation utilities for diffusion training.""" + +import re +from typing import Any, List, Literal, Optional + +import torch + +from axolotl.utils.logging import get_logger + +from .utils import create_bidirectional_attention_mask + +LOG = get_logger(__name__) + + +def generate_samples( + model: torch.nn.Module, + tokenizer: Any, + dataloader: Optional[Any] = None, + num_generation_samples: int = 3, + max_length: int = 100, + num_diffusion_steps: int = 128, + temperature: float = 0.0, + mask_token_id: int = 32000, + mode: Literal["random", "completion"] = "random", + completion_tokens: int = 0, + target_mask_ratio: Optional[float] = None, +) -> List[dict]: + """ + Generate text samples using the diffusion model by randomly masking sequences from + the given dataset and running the reverse diffusion process. + + Args: + model: The wrapped or unwrapped model + tokenizer: Tokenizer for encoding/decoding + dataloader: Validation dataloader (for sampling sequences) + num_generation_samples: Number of samples to generate + max_length: Maximum length of sequences to use + num_diffusion_steps: Number of diffusion steps for generation + temperature: Temperature for sampling (0.0 = deterministic) + mask_token_id: Token ID used for masking + + Returns: + List of dictionaries with original text, masked text, and generated text + """ + if dataloader is None: + LOG.warning("No validation dataloader provided, cannot generate samples") + return [] + + unwrapped_model = model.module if hasattr(model, "module") else model + training = unwrapped_model.training + unwrapped_model.eval() + + # Resolve device robustly (some modules don't expose `.device`) + device = getattr(unwrapped_model, "device", None) + if device is None: + try: + device = next(unwrapped_model.parameters()).device + except StopIteration: + device = torch.device("cpu") + generations = [] + + # Sample sequences from validation dataset + sampled_sequences = _sample_sequences_from_dataloader( + dataloader, num_generation_samples, max_length, device + ) + LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset") + + # Generate samples using reverse diffusion process + with torch.no_grad(): + for sample in sampled_sequences: + if isinstance(sample, dict): + original_sequence = sample.get("input_ids") + labels_seq = sample.get("labels") + attn_seq = sample.get("attention_mask") + else: + original_sequence = sample + labels_seq = None + attn_seq = None + generation_result = generate( + unwrapped_model, + tokenizer, + original_sequence, + num_diffusion_steps, + temperature, + mask_token_id, + mode=mode, + completion_tokens=completion_tokens, + target_mask_ratio=target_mask_ratio, + labels=labels_seq, + attention_mask=attn_seq, + ) + generations.append(generation_result) + + # Restore prior training state + if training: + unwrapped_model.train() + else: + unwrapped_model.eval() + + return generations + + +def _sample_sequences_from_dataloader( + dataloader: Any, num_samples: int, max_length: int, device: torch.device +) -> List[Any]: + """Sample sequences from validation dataloader.""" + sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = [] + sample_count = 0 + + # Skip a random number of batches (we could be more clever about this) + skip_batches = torch.randint(0, 10, (1,)).item() + batch_count = 0 + + for batch in dataloader: + # Skip some batches for variety + if batch_count < skip_batches: + batch_count += 1 + continue + + if sample_count >= num_samples: + break + + batch_count += 1 + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask") + labels = batch.get("labels") + + # Randomly sample from sequences in this batch + batch_indices = torch.randperm(input_ids.size(0)).tolist() + + for i in batch_indices: + if sample_count >= num_samples: + break + + # Get actual sequence length (non-padded) + if attention_mask is not None: + seq_len = attention_mask[i].sum().item() + else: + seq_len = input_ids.size(1) + + if seq_len < 10: + continue + + # Determine truncation length + max_total = min(seq_len, max_length) + if labels is not None: + labels_i = labels[i][:seq_len] + answer_mask = labels_i != -100 + if not answer_mask.any(): + # No answer tokens; skip for SFT masking + continue + first_ans_idx = int( + torch.nonzero(answer_mask, as_tuple=False)[0].item() + ) + prompt_len = first_ans_idx + if prompt_len >= max_total: + # Prompt alone reaches cap; cannot include any answer + continue + remaining_answer = int(answer_mask[prompt_len:].sum().item()) + allowed_answer = max_total - prompt_len + take_answer = min(remaining_answer, allowed_answer) + if take_answer <= 0: + continue + actual_length = prompt_len + take_answer + else: + actual_length = max_total + + # Extract the (possibly truncated) sequence + sequence = input_ids[i][:actual_length].unsqueeze(0).to(device) + attn_seq = ( + attention_mask[i][:actual_length].unsqueeze(0).to(device) + if attention_mask is not None + else None + ) + if labels is not None: + labels_seq = labels[i][:actual_length].unsqueeze(0).to(device) + sampled_sequences.append( + { + "input_ids": sequence, + "labels": labels_seq, + "attention_mask": attn_seq, + } + ) + else: + if attn_seq is not None: + sampled_sequences.append( + {"input_ids": sequence, "attention_mask": attn_seq} + ) + else: + sampled_sequences.append(sequence) + sample_count += 1 + + return sampled_sequences + + +def generate( + model: torch.nn.Module, + tokenizer: Any, + original_sequence: torch.Tensor, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, + *, + mode: Literal["random", "completion"] = "random", + completion_tokens: int = 0, + target_mask_ratio: Optional[float] = None, + labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> dict: + """Generate a single sample using reverse diffusion.""" + # Get original text for comparison + original_text = tokenizer.decode( + original_sequence[0].cpu(), skip_special_tokens=True + ) + + # Build masked sequence + if ( + labels is not None + and labels.numel() > 0 + and (labels == -100).any() + and (labels != -100).any() + ): + # SFT case: completely mask all answer tokens (labels != -100) + total_tokens = original_sequence.size(1) + masked_indices = (labels != -100).to(dtype=torch.bool) + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + masked_tokens = int(masked_indices.sum().item()) + mask_ratio = masked_tokens / max(int(total_tokens), 1) + elif mode == "completion" and completion_tokens > 0: + # Append mask tokens to the right for completion + total_tokens = original_sequence.size(1) + int(completion_tokens) + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, -int(completion_tokens) :] = True + + append = torch.full( + (1, int(completion_tokens)), mask_token_id, device=original_sequence.device + ) + masked_sequence = torch.cat([original_sequence, append], dim=1) + masked_tokens = int(completion_tokens) + mask_ratio = masked_tokens / total_tokens + else: + # Apply random masking with optional fixed ratio + total_tokens = original_sequence.size(1) + if target_mask_ratio is None: + min_ratio, max_ratio = 0.1, 0.7 + target_mask_ratio = ( + torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio + ) + target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio))) + + # Create random mask indices + mask_positions = torch.randperm(total_tokens)[:target_masked_tokens] + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, mask_positions] = True + + # Create masked sequence + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + + # Calculate actual mask ratio + masked_tokens = masked_indices.sum().item() + mask_ratio = masked_tokens / total_tokens + + # Get masked text for comparison + masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False) + masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id) + + # Run reverse diffusion process + sequence = masked_sequence.clone() + attention_mask = create_bidirectional_attention_mask( + sequence, attention_mask, sample_packing=attention_mask is not None + ) + for step in range(num_diffusion_steps): + sequence = _diffusion_step( + model, + sequence, + step, + num_diffusion_steps, + temperature, + mask_token_id, + attention_mask, + ) + generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True) + + # Collect diagnostic info + final_ids = sequence[0].detach().cpu().tolist() + orig_ids_for_render = original_sequence[0].detach().cpu().tolist() + if masked_indices is not None: + masked_positions = ( + torch.where(masked_indices[0])[0].detach().cpu().tolist() + if masked_indices.ndim == 2 + else [] + ) + else: + masked_positions = [] + + result = { + "original": original_text, + "masked": masked_text, + "generated": generated_text, + "mask_ratio": mask_ratio, + "masked_tokens": masked_tokens, + "total_tokens": total_tokens, + "generated_ids": final_ids, + "masked_positions": masked_positions, + "orig_ids": orig_ids_for_render, + "formatted": ( + f"Original: '{original_text}' → Masked: '{masked_text}' " + f"({mask_ratio:.1%}) → Generated: '{generated_text}'" + ), + } + + return result + + +def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str: + """Clean up masked text for display.""" + mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False) + cleaned = masked_text.replace(mask_token_repr, "[MASK]") + + # Remove literal special token strings + if hasattr(tokenizer, "special_tokens_map"): + for token_value in tokenizer.special_tokens_map.values(): + if token_value and isinstance(token_value, str): + cleaned = cleaned.replace(token_value, "") + + # Normalize whitespace but preserve newlines + cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n") + cleaned = re.sub(r"[ \t]+", " ", cleaned) + cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip() + return cleaned + + +def _diffusion_step( + model: torch.nn.Module, + sequence: torch.Tensor, + step: int, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Perform a single diffusion step with remasking.""" + # Only process if there are masked tokens remaining + current_mask = sequence == mask_token_id + if not current_mask.any(): + return sequence + + # Create or use provided attention mask + if attention_mask is None: + batch_size, seq_len = sequence.shape + attention_mask = torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device + ) + + # Forward pass + outputs = model(input_ids=sequence, attention_mask=attention_mask) + logits = outputs.logits + + # Only sample at currently masked positions + if current_mask.any(): + masked_logits = logits[current_mask] + + # Apply temperature scaling + if temperature > 0: + scaled_logits = masked_logits / temperature + else: + scaled_logits = masked_logits + + # Suppress mask token in outputs + scaled_logits[:, mask_token_id] = -float("inf") + + if temperature > 0: + # Add Gumbel noise for sampling + gumbel_noise = -torch.log( + -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32)) + ) + gumbel_logits = scaled_logits + gumbel_noise + predicted_tokens = torch.argmax(gumbel_logits, dim=-1) + else: + predicted_tokens = torch.argmax(scaled_logits, dim=-1) + + # Calculate probabilities for confidence scoring + probs = torch.softmax(scaled_logits, dim=-1) + predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens] + + # Determine how many tokens to unmask this step + remaining_masked = current_mask.sum().item() + if step == num_diffusion_steps - 1: + num_to_unmask = remaining_masked + else: + unmask_ratio = 1.0 / (num_diffusion_steps - step) + num_to_unmask = max(1, int(remaining_masked * unmask_ratio)) + + # Select highest confidence predictions to unmask + if num_to_unmask >= remaining_masked: + sequence[current_mask] = predicted_tokens + else: + _, top_indices = predicted_token_probs.topk(num_to_unmask) + mask_positions = torch.where(current_mask)[1] + positions_to_unmask = mask_positions[top_indices] + sequence[0, positions_to_unmask] = predicted_tokens[top_indices] + + return sequence diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py new file mode 100644 index 000000000..c31f48b03 --- /dev/null +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -0,0 +1,41 @@ +"""Diffusion LM training plugin for Axolotl.""" + +from peft import PeftModel +from transformers import PreTrainedModel + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .trainer import DiffusionTrainer + +LOG = get_logger(__name__) + + +class DiffusionPlugin(BasePlugin): + """ + Plugin for diffusion language model training. + + This plugin enables diffusion-based training using the LLaDA approach, which uses + random masking and bidirectional attention to train language models. + """ + + def __init__(self): + super().__init__() + self.cfg = None + + def get_input_args(self) -> str: + """Returns the pydantic model for LLaDA plugin arguments.""" + return "axolotl.integrations.diffusion.DiffusionArgs" + + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Perform actions after model is loaded.""" + self.cfg = cfg + + def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: + """Return custom trainer class for diffusion training.""" + return DiffusionTrainer + + def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): + """Configure trainer after creation.""" + trainer.set_config(cfg) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py new file mode 100644 index 000000000..42b2468f4 --- /dev/null +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -0,0 +1,301 @@ +"""Custom trainer for diffusion LM training.""" + +from typing import Any, Literal + +import torch +import torch.nn.functional as F +from torch import nn + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .callbacks import DiffusionGenerationCallback +from .utils import create_bidirectional_attention_mask + +LOG = get_logger(__name__) + + +class DiffusionTrainer(AxolotlTrainer): + """Custom trainer for diffusion LM training that overrides loss computation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cfg = None + self._special_token_ids = None + + def set_config(self, config: DictDefault): + """Set config for diffusion training.""" + self.cfg = config + self._cache_special_token_ids() + self._resolve_mask_token_id() + + token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0)) + LOG.info(f"Diffusion: using mask_token_id={token_id}") + + if getattr(config.diffusion, "generate_samples", True): + generation_callback = DiffusionGenerationCallback(self) + self.add_callback(generation_callback) + + def _resolve_mask_token_id(self) -> None: + """Ensure mask_token_id is valid for the current tokenizer.""" + from .utils import resolve_mask_token_id + + tokenizer = getattr(self, "processing_class", None) + if tokenizer is None: + return + + mid = resolve_mask_token_id( + tokenizer, + self.cfg, + allow_add=True, + model=getattr(self, "model", None), + ) + try: + self.cfg.diffusion.mask_token_id = int(mid) + except Exception: + pass + + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, torch.Tensor], + return_outputs: bool = False, + num_items_in_batch: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Override compute_loss to use diffusion loss.""" + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + labels = inputs.get("labels") + + if input_ids is None: + raise ValueError("input_ids is required for diffusion training") + + loss, outputs = self._compute_diffusion_loss( + model, input_ids, attention_mask, labels + ) + + if return_outputs: + return loss, outputs + return loss + + def _cache_special_token_ids(self): + """Cache special token IDs to avoid repeated tokenizer access.""" + if self.processing_class is None: + self._special_token_ids = set() + return + + tokenizer = self.processing_class + special_tokens = set() + + if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None: + special_tokens.add(tokenizer.bos_token_id) + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + special_tokens.add(tokenizer.eos_token_id) + if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + special_tokens.add(tokenizer.pad_token_id) + + self._special_token_ids = special_tokens + + def _forward_process( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + eps: float = 1e-3, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward noising process. A timestep is sampled along the process, and tokens are + masked with probability determined by the configured noise schedule. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + eps: Small epsilon value for minimum masking probability. + + Returns: + noisy_batch: Input with some tokens masked. + masked_indices: Boolean mask indicating which tokens were masked. + p_mask: Masking probabilities for each token [batch_size, seq_len]. + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Sample random timesteps for each sample in batch + t = torch.rand(batch_size, device=device) + p_mask = (1 - eps) * t + eps # [batch_size] + p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] + + # Don't mask padding tokens if attention_mask is provided + if attention_mask is not None: + valid_mask = attention_mask.bool() + p_mask = p_mask * valid_mask.float() + + # Create mask to exclude special tokens + special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) + if self._special_token_ids: + for token_id in self._special_token_ids: + special_token_mask |= input_ids == token_id + + # Create random mask based on p_mask + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + masked_indices = masked_indices & ~special_token_mask + if attention_mask is not None: + masked_indices = masked_indices & attention_mask.bool() + + # For SFT data, only mask answer tokens + if labels is not None: + answer_mask = labels != -100 + masked_indices = masked_indices & answer_mask + + # Create masked input + mask_token_id = int(self.cfg.diffusion.mask_token_id) + mask_value = torch.full_like(input_ids, mask_token_id) + noisy_batch = torch.where(masked_indices, mask_value, input_ids) + + return noisy_batch, masked_indices, p_mask + + def _compute_diffusion_loss( + self, + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | Any]: + """ + Compute diffusion loss. + + Args: + model: The model to compute loss for. + input_ids: Ground truth token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + + Returns: + loss: Cross-entropy loss. + metrics: Dictionary of metrics. + """ + # Short-circuit empty sequences + if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0: + zero = torch.tensor( + 0.0, + device=(input_ids.device if input_ids is not None else None), + requires_grad=True, + ) + return zero, {} + + # If an attention_mask is provided and all positions are padding for every + # sample in this batch, skip the step. + if attention_mask is not None: + if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all(): + zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + return zero, {} + + # Apply forward process + noisy_batch, masked_indices, p_mask = self._forward_process( + input_ids, attention_mask, labels, self.cfg.diffusion.eps + ) + + # Create bidirectional attention mask + bidirectional_mask = create_bidirectional_attention_mask( + input_ids, attention_mask, sample_packing=self.cfg.sample_packing + ) + + # Forward pass + outputs = model( + input_ids=noisy_batch.long(), + attention_mask=bidirectional_mask, + ) + logits = outputs.logits + + if masked_indices.sum() > 0: + valid_indices = torch.where(masked_indices) + batch_indices, seq_indices = valid_indices + + masked_logits = logits[batch_indices, seq_indices] + masked_targets = input_ids[batch_indices, seq_indices] + masked_p_mask = p_mask[batch_indices, seq_indices] + + # Compute cross-entropy loss without reduction + token_loss = F.cross_entropy( + masked_logits.float(), masked_targets, reduction="none" + ) + + if self.cfg.diffusion.importance_weighting: + masked_p_mask = masked_p_mask.float() + weighted_loss = token_loss / masked_p_mask + else: + weighted_loss = token_loss + + if labels is not None: + # For SFT data: normalize by answer token count per sample + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] + + # Get batch indices for masked tokens + masked_batch_indices = batch_indices + + # Sum losses per sample and divide by answer length + batch_size = input_ids.shape[0] + loss_per_sample = torch.zeros(batch_size, device=input_ids.device) + for i in range(batch_size): + sample_mask = masked_batch_indices == i + if sample_mask.sum() > 0: + sample_loss = weighted_loss[sample_mask].sum() + denom = answer_lengths[i].clamp(min=1.0) + loss_per_sample[i] = sample_loss / denom + + loss = loss_per_sample.mean() + else: + # Non-SFT: when importance weighting is enabled, use unbiased estimator + # (sum(loss/p) / total_tokens). Otherwise, average over masked tokens + # for stable scaling across varying mask ratios. + if self.cfg.diffusion.importance_weighting: + loss = weighted_loss.sum() / ( + input_ids.shape[0] * input_ids.shape[1] + ) + else: + loss = weighted_loss.mean() + + ce_loss = token_loss.mean() + + # Compute accuracy on masked tokens + with torch.no_grad(): + pred_tokens = masked_logits.argmax(dim=-1) + accuracy = (pred_tokens == masked_targets).float().mean() + else: + loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + accuracy = torch.tensor(0.0, device=input_ids.device) + ce_loss = torch.tensor(0.0, device=input_ids.device) + masked_p_mask = torch.tensor(1.0, device=input_ids.device) + + avg_p_mask = ( + p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0 + ) + metrics = { + "loss": loss.item(), + "accuracy": accuracy.item(), + "mask_ratio": masked_indices.float().mean().item(), + "num_masked_tokens": (masked_indices.sum().item(), "sum"), + "avg_p_mask": avg_p_mask, + "ce_loss": ce_loss.item(), + } + + # If doing SFT training, log answer-specific metrics + if self.cfg.datasets is not None: + with torch.no_grad(): + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # type: ignore + total_answer_tokens = answer_mask.sum().item() # type: ignore + total_tokens = labels.numel() # type: ignore + metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1) + metrics["avg_answer_length"] = answer_lengths.mean().item() + + if self.cfg.diffusion.importance_weighting: + metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() + + train_eval: Literal["train", "eval"] = "train" if model.training else "eval" + self.store_metrics(metrics, train_eval=train_eval) + + return loss, outputs diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py new file mode 100644 index 000000000..47abf6fec --- /dev/null +++ b/src/axolotl/integrations/diffusion/utils.py @@ -0,0 +1,159 @@ +"""Shared utilities for diffusion integration.""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from axolotl.utils.dict import DictDefault + + +def resolve_mask_token_id( + tokenizer: Any, + cfg: DictDefault, + *, + allow_add: bool, + model: Any | None = None, + default_token: str = "<|diffusion_mask|>", +) -> int: + """Resolve mask token id. Training may add a new special token; inference won't.""" + # Determine vocab size if available + vocab_size = None + if tokenizer is not None: + if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None: + try: + vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type] + except Exception: + vocab_size = None + elif hasattr(tokenizer, "__len__"): + try: + vocab_size = int(len(tokenizer)) + except Exception: + vocab_size = None + + # Use explicit id from config if provided + diffusion_cfg = getattr(cfg, "diffusion", None) + # Fallback to top-level attr names only if nested missing (shouldn't happen) + cfg_id = ( + getattr(diffusion_cfg, "mask_token_id", None) + if diffusion_cfg is not None + else getattr(cfg, "diffusion_mask_token_id", None) + ) + if isinstance(cfg_id, int) and cfg_id >= 0: + if vocab_size is None or cfg_id < vocab_size: + return int(cfg_id) + + def _existing_special_token_id(token_str: str | None) -> int | None: + """Attempt to resolve an existing special token string to a real ID.""" + if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"): + return None + try: + token_id = tokenizer.convert_tokens_to_ids(token_str) + except Exception: + return None + + if not isinstance(token_id, int) or token_id < 0: + return None + + # Ensure it's registered as special and not UNK, and within vocab + unk_id = getattr(tokenizer, "unk_token_id", None) + specials = set(getattr(tokenizer, "all_special_tokens", []) or []) + addl = set(getattr(tokenizer, "additional_special_tokens", []) or []) + is_special = token_str in specials or token_str in addl + in_vocab = vocab_size is None or token_id < vocab_size + if ( + (unk_id is not None and token_id == unk_id) + or not is_special + or not in_vocab + ): + return None + return token_id + + # Try mask token string if provided + token_str = ( + getattr(diffusion_cfg, "mask_token_str", None) + if diffusion_cfg is not None + else getattr(cfg, "diffusion_mask_token_str", None) + ) + for candidate in (token_str, default_token): + token_id = _existing_special_token_id(candidate) + if isinstance(token_id, int): + try: + if diffusion_cfg is None: + cfg.diffusion_mask_token_id = int(token_id) # legacy fallback + else: + diffusion_cfg.mask_token_id = int(token_id) + except Exception: + pass + return int(token_id) + + # Optionally add and return a dedicated special token during training + if allow_add and hasattr(tokenizer, "add_special_tokens"): + token_to_add = token_str or default_token + try: + tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]}) + + # Resize embeddings if possible + if ( + model is not None + and hasattr(tokenizer, "__len__") + and hasattr(model, "resize_token_embeddings") + ): + try: + model.resize_token_embeddings(len(tokenizer)) + except Exception: + pass + new_id = tokenizer.convert_tokens_to_ids(token_to_add) + if isinstance(new_id, int) and new_id >= 0: + try: + if diffusion_cfg is None: + cfg.diffusion_mask_token_id = int(new_id) # legacy fallback + else: + diffusion_cfg.mask_token_id = int(new_id) + except Exception: + pass + return int(new_id) + except Exception: + pass + + # Fallback to unk or 0 (do not update cfg) + fallback = getattr(tokenizer, "unk_token_id", 0) or 0 + return int(fallback) + + +def create_bidirectional_attention_mask( + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sample_packing: bool = False, +) -> torch.Tensor: + """ + Create bidirectional attention mask to override default causal masking. + Handles sample-packed sequences where different samples are identified + by different attention mask values. + + Args: + input_ids: Input token ids [batch_size, seq_len] + attention_mask: Attention mask [batch_size, seq_len] + sample_packing: Whether sample packing is enabled + + Returns: + bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len] + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + if attention_mask is None or not sample_packing: + return torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device + ) + + # Handle sample packing: tokens can only attend within their sample + mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] + mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] + + # Tokens can attend to each other if they have the same non-zero sample ID + bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) + + # Add head dimension: [batch_size, 1, seq_len, seq_len] + return bidirectional_mask.unsqueeze(1) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 989b34aee..bcde4bf96 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -14,6 +14,7 @@ from peft import ( PeftConfig, PeftMixedModel, PeftModel, + TaskType, get_peft_model, ) from transformers import PreTrainedModel @@ -101,6 +102,15 @@ def load_lora( if cfg.peft_trainable_token_indices: lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices + # Determine the correct PEFT task type + model_cls = type(model).__name__ + if "SequenceClassification" in model_cls: + task_type = TaskType.SEQ_CLS + elif "TokenClassification" in model_cls: + task_type = TaskType.TOKEN_CLS + else: + task_type = TaskType.CAUSAL_LM + lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, @@ -112,7 +122,7 @@ def load_lora( fan_in_fan_out=cfg.lora_fan_in_fan_out, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, bias="none", - task_type="CAUSAL_LM", + task_type=task_type, **lora_config_kwargs, ) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index a9507d685..f438d6b61 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -673,6 +673,33 @@ class ModelLoader: return hf_ds_cfg + def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel: + """ + Load model with random initialization using from_config. + + Uses the selected loader when provided; otherwise falls back to the auto loader. + """ + loader = model_loader_class or self.auto_model_loader + if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]: + model = loader.from_config( + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + else: + model = loader(config=self.model_config) + + return model + + def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel: + """Load model from pretrained weights.""" + loader = model_loader_class or self.auto_model_loader + kwargs = { + "config": self.model_config, + "trust_remote_code": self.cfg.trust_remote_code or False, + **self.model_kwargs, + } + return loader.from_pretrained(self.base_model, **kwargs) + def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" skip_move_to_device = False @@ -687,7 +714,8 @@ class ModelLoader: if self.is_fsdp_enabled: if self.cfg.fsdp_config.cpu_ram_efficient_loading: skip_move_to_device = True - # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map + # Don't delete device_map for QLoRA + FSDP - it was set correctly in + # _set_device_map if ( "device_map" in self.model_kwargs and not self.is_qlora_and_fsdp_enabled @@ -716,6 +744,11 @@ class ModelLoader: or self.cfg.qlora_sharded_model_loading ) ): + if self.cfg.reinit_weights: + LOG.warning( + "reinit_weights is not supported with sharded quantized loading. " + "Loading from pretrained weights instead." + ) quant_storage = self.cfg.torch_dtype quantization_config = getattr( self.model_config, "quantization_config", None @@ -731,33 +764,12 @@ class ModelLoader: quantization_config=quantization_config, ) skip_move_to_device = True - elif ( - self.model_config.model_type in ["llama", "llama4"] - and not self.cfg.trust_remote_code - and not self.cfg.gptq - ): - # Please don't remove underscore binding without reading the fn docstring. - _ = self._configure_zero3_memory_efficient_loading() - - # Load model with random initialization if specified - if self.cfg.random_init_weights: - # AutoModel classes support the from_config method - if self.auto_model_loader in [ - AutoModelForCausalLM, - AutoModelForVision2Seq, - ]: - self.model = self.auto_model_loader.from_config( - config=self.model_config, - ) - else: - self.model = self.auto_model_loader(config=self.model_config) - else: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) elif self.model_type == "MambaLMHeadModel": + if self.cfg.reinit_weights: + LOG.warning( + "reinit_weights is not supported with MambaLMHeadModel. " + "Loading from pretrained weights instead." + ) # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() @@ -770,41 +782,27 @@ class ModelLoader: self.base_model, **self.model_kwargs, ) - elif ( - self.model_type - and self.model_type != "AutoModelForCausalLM" - and not self.cfg.trust_remote_code - ): - if self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - self.model = getattr(transformers, self.model_type).from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - elif self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) else: - # Please don't remove underscore binding without reading the fn docstring. + # Please don't remove underscore binding without reading the fn docstring _ = self._configure_zero3_memory_efficient_loading() - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) + + if ( + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code + and not self.cfg.gptq + ): + # Use model type from transformers + model_loader_class = getattr(transformers, self.model_type) + else: + # Use auto model loader (handles gptq and default cases) + model_loader_class = self.auto_model_loader + + if self.cfg.reinit_weights: + self.model = self._load_model_from_config(model_loader_class) + else: + self.model = self._load_model_from_pretrained(model_loader_class) + if is_deepspeed_zero3_enabled(): skip_move_to_device = True diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 044c278a3..a5a630cb5 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -3,8 +3,8 @@ Applies pre- and post-model load patches for various fixes and optimizations. """ -import os import importlib.util +import os from functools import cached_property import addict @@ -468,9 +468,10 @@ class PatchManager: def _apply_patch_deepspeed_zero3(self): try: - from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches + if self.cfg.activation_offloading is True and ( is_deepspeed_zero3_enabled() or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3" diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 3b38a33b7..d8ba02cb2 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -160,9 +160,11 @@ def get_state_dict(self, model, unwrap=True): state_dict[param_name] = param.cpu() torch.distributed.barrier() elif self.distributed_type == DistributedType.FSDP: - from torch.distributed.fsdp import FullStateDictConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import StateDictType + from torch.distributed.fsdp import ( + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, + ) full_state_dict_config = FullStateDictConfig( offload_to_cpu=True, rank0_only=True diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 65ccad533..678f65bee 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -1,11 +1,12 @@ """Flex attention monkey patch""" import sys -from packaging import version import torch import transformers +from packaging import version from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal + from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/deepspeed_utils.py b/src/axolotl/monkeypatch/deepspeed_utils.py index 6740f556b..d7e69e112 100644 --- a/src/axolotl/monkeypatch/deepspeed_utils.py +++ b/src/axolotl/monkeypatch/deepspeed_utils.py @@ -1,5 +1,6 @@ import importlib import importlib.util + from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index f40fe6687..7a2bbd6f9 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -17,8 +17,8 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, + AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset LOG = get_logger(__name__) diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 788f13638..8b9e4e91d 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,14 +1,14 @@ """Init for `axolotl.utils.data` module.""" -from axolotl.utils.data.streaming import ( - encode_streaming, - wrap_streaming_dataset, -) from axolotl.utils.data.rl import prepare_preference_datasets from axolotl.utils.data.sft import ( get_dataset_wrapper, prepare_datasets, ) +from axolotl.utils.data.streaming import ( + encode_streaming, + wrap_streaming_dataset, +) from axolotl.utils.data.utils import md5 __all__ = [ diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 28732e01d..ba5aec2d6 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -16,7 +16,6 @@ from transformers import PreTrainedTokenizer, ProcessorMixin from axolotl.prompters import Prompter from axolotl.utils.data.lock import FileLockLoader -from axolotl.utils.data.streaming import wrap_streaming_dataset from axolotl.utils.data.shared import ( create_train_validation_split, datasets_with_name_generator, @@ -27,6 +26,7 @@ from axolotl.utils.data.shared import ( save_preprocessed_dataset, try_load_from_hub, ) +from axolotl.utils.data.streaming import wrap_streaming_dataset from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, handle_long_seq_in_dataset, diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 751f7e253..192aca4e1 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -6,8 +6,6 @@ from importlib.metadata import version from accelerate.utils.environment import ( check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, -) -from accelerate.utils.environment import ( get_gpu_info, ) from packaging.version import Version, parse diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e4c1fdf29..d612ec8a5 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -106,6 +106,12 @@ class AxolotlInputConfig( "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs" }, ) + reinit_weights: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Reinitialize model weights randomly instead of loading pretrained weights" + }, + ) trainer_cls: str | None = Field( default=None, diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 49add8081..64018ca48 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -14,7 +14,6 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType - LOG = get_logger(__name__) SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py new file mode 100644 index 000000000..cc3d8070b --- /dev/null +++ b/tests/e2e/test_diffusion.py @@ -0,0 +1,139 @@ +"""E2E smoke test for diffusion training plugin.""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists + + +class TestDiffusion: + """Test case for diffusion training plugin.""" + + def test_diffusion_smoke_test(self, temp_dir): + """ + Smoke test for diffusion training to ensure the plugin loads and trains without + error. + """ + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 256, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + "logging_steps": 1, + "eval_steps": 3, + # Diffusion-specific config + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion": { + # sample generation + "generate_samples": True, + "generation_interval": 1, + "num_generation_samples": 1, + "generation_steps": 2, + "generation_max_length": 32, + "generation_temperature": 0.0, + # training-specific + "mask_token_id": 16, + "eps": 1e-3, + "importance_weighting": False, + }, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + def test_diffusion_sft_labels(self, temp_dir): + """Test that diffusion training properly handles SFT data with labels.""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 256, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + "logging_steps": 1, + "eval_steps": 2, + # Diffusion-specific config + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion": { + # sample generation + "generate_samples": True, + "generation_interval": 1, + "num_generation_samples": 1, + "generation_steps": 2, + "generation_max_length": 32, + "generation_temperature": 0.0, + # training-specific + "mask_token_id": 16, + "eps": 1e-3, + "importance_weighting": True, + }, + # Ensure we have proper SFT labels + "train_on_inputs": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + # Verify that the dataset has labels + sample = dataset_meta.train_dataset[0] + assert "labels" in sample, "SFT dataset should have labels" + + # Check that some labels are -100 (prompt tokens) + labels = sample["labels"] + if hasattr(labels, "tolist"): + labels = labels.tolist() + assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens" + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py new file mode 100644 index 000000000..141d8d150 --- /dev/null +++ b/tests/integrations/test_diffusion.py @@ -0,0 +1,274 @@ +"""Tests for diffusion trainer integration.""" + +# pylint: disable=redefined-outer-name,protected-access + +from unittest.mock import Mock + +import pytest +import torch + +from axolotl.integrations.diffusion import DiffusionTrainer +from axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask +from axolotl.utils.dict import DictDefault + + +@pytest.fixture +def mock_tokenizer(): + """Create a mock tokenizer.""" + tokenizer = Mock() + tokenizer.bos_token_id = 1 + tokenizer.eos_token_id = 2 + tokenizer.pad_token_id = 0 + return tokenizer + + +@pytest.fixture +def diffusion_config(): + """Create a diffusion config.""" + return DictDefault( + { + "diffusion": { + "mask_token_id": 32000, + "eps": 1e-3, + "importance_weighting": False, + }, + "sample_packing": False, + } + ) + + +@pytest.fixture +def diffusion_trainer_instance(mock_tokenizer, diffusion_config): + """Create a diffusion trainer instance for testing methods directly.""" + # Create a minimal trainer instance just for testing methods + trainer = object.__new__(DiffusionTrainer) # Bypass __init__ + trainer.cfg = diffusion_config + trainer._special_token_ids = {0, 1, 2} # pad, bos, eos + trainer.processing_class = mock_tokenizer + trainer.store_metrics = Mock() # Mock metrics storage + return trainer + + +class TestDiffusionTrainer: + """Test the DiffusionTrainer class.""" + + def test_forward_process_basic(self, diffusion_trainer_instance): + """Test basic forward process without labels.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process(input_ids, eps=0.1) + ) + + # Check shapes + assert noisy_batch.shape == input_ids.shape + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + + # Check that special tokens are not masked + special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0) + assert not masked_indices[special_token_positions].any() + + # Check that mask token is applied + mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id + masked_positions = masked_indices + if masked_positions.any(): + assert (noisy_batch[masked_positions] == mask_token_id).all() + + def test_forward_process_with_labels(self, diffusion_trainer_instance): + """Test forward process with SFT labels.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process( + input_ids, labels=labels, eps=0.1 + ) + ) + + # Check shapes + assert noisy_batch.shape == input_ids.shape + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + + # Check that only answer tokens can be masked (where labels != -100) + non_answer_mask = labels == -100 + + # No masking should occur on non-answer tokens + assert not masked_indices[non_answer_mask].any() + + # p_mask should be the same for all positions (sampled timestep), + # but masking is only applied to answer tokens + assert p_mask.shape == input_ids.shape + # Verify that masked_indices respects the answer mask + assert not masked_indices[non_answer_mask].any() + + def test_forward_process_with_attention_mask(self, diffusion_trainer_instance): + """Test forward process with attention mask.""" + input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) + + _, masked_indices, p_mask = diffusion_trainer_instance._forward_process( + input_ids, attention_mask=attention_mask, eps=0.1 + ) + + # Check that padding tokens are not masked + padding_positions = attention_mask == 0 + assert not masked_indices[padding_positions].any() + assert (p_mask[padding_positions] == 0).all() + + def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): + """Test bidirectional attention mask without sample packing.""" + input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) + + mask = create_bidirectional_attention_mask(input_ids) + + # Should be all-to-all attention + expected_shape = (1, 1, 4, 4) + assert mask.shape == expected_shape + assert mask.all() + + def test_bidirectional_attention_mask_with_packing( + self, diffusion_trainer_instance + ): + """Test bidirectional attention mask with sample packing.""" + diffusion_trainer_instance.cfg.sample_packing = True + input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) + # Sample IDs: first sample (1), second sample (2) + attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) + + mask = create_bidirectional_attention_mask( + input_ids, attention_mask, sample_packing=True + ) + + # Check that tokens within same sample can attend to each other + # but not across samples + assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other + assert mask[0, 0, 1, 2].item() + assert not mask[0, 0, 0, 3].item() # Can't attend across samples + assert not mask[0, 0, 2, 4].item() + assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other + + def test_compute_loss_basic(self, diffusion_trainer_instance): + """Test basic loss computation.""" + # Mock model that returns logits + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 5 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + mock_model.return_value = mock_outputs + mock_model.training = True + + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + loss, outputs = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids + ) + + # Check that loss is computed + assert isinstance(loss, torch.Tensor) + assert loss.requires_grad + assert outputs == mock_outputs + + # Check that metrics were stored + diffusion_trainer_instance.store_metrics.assert_called_once() + + def test_compute_loss_sft(self, diffusion_trainer_instance): + """Test loss computation with SFT labels.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 5 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + mock_model.return_value = mock_outputs + mock_model.training = True + diffusion_trainer_instance.cfg.datasets = Mock() + + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids, labels=labels + ) + + # Check that loss is computed + assert isinstance(loss, torch.Tensor) + assert loss.requires_grad + + # Check that SFT metrics were added + call_args = diffusion_trainer_instance.store_metrics.call_args[0][0] + assert "answer_ratio" in call_args + assert "avg_answer_length" in call_args + + def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): + """Test loss computation when no tokens are masked.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 3 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size) + mock_model.return_value = mock_outputs + mock_model.training = True + + # Only special tokens (which won't be masked) + input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids + ) + + # Loss should be zero when no tokens are masked + assert loss.item() == 0.0 + assert loss.requires_grad + + def test_cache_special_token_ids(self, mock_tokenizer): + """Test caching of special token IDs.""" + trainer = object.__new__(DiffusionTrainer) + trainer.processing_class = mock_tokenizer + trainer._cache_special_token_ids() + assert trainer._special_token_ids == {0, 1, 2} + + def test_cache_special_token_ids_no_tokenizer(self): + """Test caching when no tokenizer is available.""" + trainer = object.__new__(DiffusionTrainer) + trainer.processing_class = None + trainer._cache_special_token_ids() + + assert trainer._special_token_ids == set() + + def test_main_compute_loss_interface(self, diffusion_trainer_instance): + """Test the main compute_loss interface.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + mock_outputs.logits = torch.randn(1, 5, 1000) + mock_model.return_value = mock_outputs + mock_model.training = True + + inputs = { + "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), + "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), + } + + # Test without return_outputs + loss = diffusion_trainer_instance.compute_loss(mock_model, inputs) + assert isinstance(loss, torch.Tensor) + + # Test with return_outputs + loss, outputs = diffusion_trainer_instance.compute_loss( + mock_model, inputs, return_outputs=True + ) + assert isinstance(loss, torch.Tensor) + assert outputs == mock_outputs + + def test_missing_input_ids_raises_error(self, diffusion_trainer_instance): + """Test that missing input_ids raises ValueError.""" + mock_model = Mock() + inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} + + with pytest.raises(ValueError, match="input_ids is required"): + diffusion_trainer_instance.compute_loss(mock_model, inputs) diff --git a/tests/integrations/test_diffusion_callback.py b/tests/integrations/test_diffusion_callback.py new file mode 100644 index 000000000..3e8785fe0 --- /dev/null +++ b/tests/integrations/test_diffusion_callback.py @@ -0,0 +1,92 @@ +"""Tests for diffusion generation callback dataloader selection and triggering.""" + +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from axolotl.integrations.diffusion import DiffusionGenerationCallback + + +class DummyTrainer: + """Minimal trainer double with required attributes/methods for the callback.""" + + def __init__(self, use_eval: bool): + # Config used by callback + self.cfg = SimpleNamespace( + diffusion=SimpleNamespace( + generation_interval=1, + num_generation_samples=1, + generation_max_length=32, + generation_steps=4, + generation_temperature=0.0, + mask_token_id=16, + ), + use_wandb=False, + ) + + # Model/tokenizer are passed through to generate_samples; not used here + self.model = Mock() + self.processing_class = Mock() + + # Datasets and loaders + self.eval_dataset = object() if use_eval else None + self._train_loader = object() + self._eval_loader = object() + + # State for world process check + self.state = SimpleNamespace(is_world_process_zero=True) + + # Track which loader was requested + self.requested: list[str] = [] + + def get_train_dataloader(self): + self.requested.append("train") + return self._train_loader + + def get_eval_dataloader(self): + self.requested.append("eval") + return self._eval_loader + + +@pytest.mark.parametrize("use_eval", [False, True]) +def test_callback_uses_correct_dataloader(monkeypatch, use_eval): + trainer = DummyTrainer(use_eval=use_eval) + callback = DiffusionGenerationCallback(trainer) + + captured = {} + + # Patch generate_samples in the callback module's namespace + def fake_generate_samples(**kwargs): + captured["dataloader"] = kwargs.get("dataloader") + # Return one dummy sample to exercise logging path + return [ + { + "original": "o", + "masked": "m", + "generated": "g", + "mask_ratio": 0.5, + "masked_tokens": 1, + "total_tokens": 2, + } + ] + + monkeypatch.setattr( + "axolotl.integrations.diffusion.callbacks.generate_samples", + fake_generate_samples, + ) + + # Trigger at step 1 (interval=1) + args = SimpleNamespace() + state = SimpleNamespace(global_step=1) + control = SimpleNamespace() + + callback.on_step_end(args=args, state=state, control=control) + + # Assert the expected dataloader path was used + if use_eval: + assert trainer.requested[0] == "eval" + assert captured["dataloader"] is trainer._eval_loader + else: + assert trainer.requested[0] == "train" + assert captured["dataloader"] is trainer._train_loader diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 54acbb5e4..2c1f9f936 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,12 +5,12 @@ from unittest.mock import Mock, patch from datasets import IterableDataset -from axolotl.utils.dict import DictDefault +from axolotl.utils.config import validate_config from axolotl.utils.data.sft import ( _prepare_streaming_dataset, prepare_datasets, ) -from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault class TestStreamingConfig(unittest.TestCase):