Diffusion trainer fix: shift logits to align with input tokens (#3191)
* shift logits for diffusion generate * delete unused * diffusion trainer: token shift
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .utils import create_bidirectional_attention_mask
|
||||
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -360,7 +360,7 @@ def _diffusion_step(
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
logits = shift_logits_to_input_positions(outputs.logits)
|
||||
|
||||
# Only sample at currently masked positions
|
||||
if current_mask.any():
|
||||
|
||||
@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
from .utils import create_bidirectional_attention_mask
|
||||
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
input_ids=noisy_batch.long(),
|
||||
attention_mask=bidirectional_mask,
|
||||
)
|
||||
logits = outputs.logits
|
||||
logits = shift_logits_to_input_positions(outputs.logits)
|
||||
|
||||
if masked_indices.sum() > 0:
|
||||
valid_indices = torch.where(masked_indices)
|
||||
|
||||
@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(
|
||||
|
||||
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
||||
return bidirectional_mask.unsqueeze(1)
|
||||
|
||||
|
||||
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Align next-token logits with their input token positions for diffusion."""
|
||||
if logits.size(1) <= 1:
|
||||
return logits
|
||||
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
|
||||
Reference in New Issue
Block a user