From 9d4d39e939b3e44298f0c5e1f1b05c7b515fc7a6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 27 Oct 2025 03:42:01 -0400 Subject: [PATCH] Diffusion trainer fix: shift logits to align with input tokens (#3191) * shift logits for diffusion generate * delete unused * diffusion trainer: token shift --- src/axolotl/integrations/diffusion/generation.py | 4 ++-- src/axolotl/integrations/diffusion/trainer.py | 4 ++-- src/axolotl/integrations/diffusion/utils.py | 7 +++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py index 49e3cdfae..ec517fd23 100644 --- a/src/axolotl/integrations/diffusion/generation.py +++ b/src/axolotl/integrations/diffusion/generation.py @@ -7,7 +7,7 @@ import torch from axolotl.utils.logging import get_logger -from .utils import create_bidirectional_attention_mask +from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions LOG = get_logger(__name__) @@ -360,7 +360,7 @@ def _diffusion_step( # Forward pass outputs = model(input_ids=sequence, attention_mask=attention_mask) - logits = outputs.logits + logits = shift_logits_to_input_positions(outputs.logits) # Only sample at currently masked positions if current_mask.any(): diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py index 42b2468f4..dfaef2a48 100644 --- a/src/axolotl/integrations/diffusion/trainer.py +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from .callbacks import DiffusionGenerationCallback -from .utils import create_bidirectional_attention_mask +from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions LOG = get_logger(__name__) @@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer): input_ids=noisy_batch.long(), attention_mask=bidirectional_mask, ) - logits = outputs.logits + logits = shift_logits_to_input_positions(outputs.logits) if masked_indices.sum() > 0: valid_indices = torch.where(masked_indices) diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py index 47abf6fec..b6f71c07b 100644 --- a/src/axolotl/integrations/diffusion/utils.py +++ b/src/axolotl/integrations/diffusion/utils.py @@ -157,3 +157,10 @@ def create_bidirectional_attention_mask( # Add head dimension: [batch_size, 1, seq_len, seq_len] return bidirectional_mask.unsqueeze(1) + + +def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor: + """Align next-token logits with their input token positions for diffusion.""" + if logits.size(1) <= 1: + return logits + return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)