Diffusion trainer fix: shift logits to align with input tokens (#3191)
* shift logits for diffusion generate * delete unused * diffusion trainer: token shift
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
|||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .utils import create_bidirectional_attention_mask
|
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ def _diffusion_step(
|
|||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
||||||
logits = outputs.logits
|
logits = shift_logits_to_input_positions(outputs.logits)
|
||||||
|
|
||||||
# Only sample at currently masked positions
|
# Only sample at currently masked positions
|
||||||
if current_mask.any():
|
if current_mask.any():
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .callbacks import DiffusionGenerationCallback
|
from .callbacks import DiffusionGenerationCallback
|
||||||
from .utils import create_bidirectional_attention_mask
|
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
input_ids=noisy_batch.long(),
|
input_ids=noisy_batch.long(),
|
||||||
attention_mask=bidirectional_mask,
|
attention_mask=bidirectional_mask,
|
||||||
)
|
)
|
||||||
logits = outputs.logits
|
logits = shift_logits_to_input_positions(outputs.logits)
|
||||||
|
|
||||||
if masked_indices.sum() > 0:
|
if masked_indices.sum() > 0:
|
||||||
valid_indices = torch.where(masked_indices)
|
valid_indices = torch.where(masked_indices)
|
||||||
|
|||||||
@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(
|
|||||||
|
|
||||||
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
||||||
return bidirectional_mask.unsqueeze(1)
|
return bidirectional_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Align next-token logits with their input token positions for diffusion."""
|
||||||
|
if logits.size(1) <= 1:
|
||||||
|
return logits
|
||||||
|
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user