Compare commits


1 Commit

Author: Wing Lian
SHA1: cf8c93e2ee
Message: wip
Date: 2025-08-19 09:36:57 -04:00
12 changed files with 544 additions and 781 deletions

View File

@@ -274,6 +274,18 @@ class AxolotlTrainer(
             num_workers=self.args.dataloader_num_workers,
             rank=self.args.process_index,
         )
+        if (self.args.accelerator_config is not None
+            and self.args.accelerator_config.split_batches
+            and self.args.accelerator_config.dispatch_batches
+        ):
+            if self.args.sample_packing and self.args.pretraining:
+                if not self.args.eval_sample_packing and not is_training:
+                    dataloader_params["batch_size"] *= self.accelerator.num_processes
+                else:
+                    dataloader_params["batch_size"] = self.accelerator.num_processes
+            elif not self.args.sample_packing and self.args.pretraining:
+                dataloader_params["batch_size"] *= self.accelerator.num_processes
         if self.args.sample_packing and (
             (is_training and not self.args.pretraining)
             or (not is_training and self.args.eval_sample_packing is not False)

View File

@@ -64,25 +64,11 @@ learning_rate: 3e-4
 ## Supported Models
 
-Currently supported base model types:
-
-- **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
-- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
-
-The plugin automatically creates custom model classes that inherit from the base model
-while adding diffusion training capabilities. This provides full compatibility with
-HuggingFace's ecosystem for saving, loading, and inference.
+Any models that support 4D attention masks should work out of the box. If not, please
+create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues)!
 
 ## How It Works
 
-### Custom Model Architecture
-
-The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
-standard HuggingFace models. During training, these models:
-
-1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
-2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
-3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
-
 ### Random Masking
 
 During training, tokens are randomly masked based on a sampled timestep:
 
 - Sample timestep `t` uniformly from [0, 1]
@@ -90,10 +76,11 @@ During training, tokens are randomly masked based on a sampled timestep:
 - Randomly mask tokens with probability `p`
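
To make the schedule above concrete, here is a minimal PyTorch sketch of the forward masking step. This is an illustration only, not part of this commit; the plugin's actual implementation additionally skips special tokens, padding, and prompt tokens.

```python
import torch

def forward_mask(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    """Mask each token with probability p = (1 - eps) * t + eps for a per-sample timestep t."""
    batch_size, seq_len = input_ids.shape
    t = torch.rand(batch_size, device=input_ids.device)          # timestep per sample
    p_mask = ((1 - eps) * t + eps)[:, None].expand(-1, seq_len)  # [batch, seq_len]
    masked_indices = torch.rand_like(p_mask) < p_mask            # Bernoulli(p) per token
    noisy_ids = torch.where(masked_indices, mask_token_id, input_ids)
    return noisy_ids, masked_indices, p_mask
```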
 ### Bidirectional Attention
 
-The models override causal attention with bidirectional attention:
-- Creates 4D attention masks allowing all-to-all attention
-- Maintains proper padding and sample packing masks
-- Compatible with standard HuggingFace attention implementations
+The plugin uses native 4D attention masks to:
+- Enable bidirectional attention without patches
+- Allow all tokens to attend to all other tokens
+- Maintain proper padding masks
+- Work with modern `transformers` models out of the box
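
For intuition, such a 4D mask can be built from a packed attention mask with plain tensor ops. This small sketch mirrors the comparison-based mask used elsewhere in this change; the shapes and sample ids are illustrative.

```python
import torch

# attention_mask carries sample ids when sequences are packed: two samples in one row
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]])

# [batch, 1, seq, seq]: True where query and key belong to the same non-padding sample
same_sample = attention_mask.unsqueeze(2) == attention_mask.unsqueeze(1)
bidirectional_4d = (same_sample & (attention_mask.unsqueeze(2) > 0)).unsqueeze(1)
```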
 ### Diffusion Loss
@@ -103,22 +90,6 @@ Loss is computed only on masked tokens with (optional) importance weighting:
 loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
 ```
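
Spelled out in PyTorch, and assuming `logits`, `masked_indices`, and `p_mask` come from the model and the forward process above, this is a masked cross-entropy with optional importance weighting. A sketch of the pre-training case only; the SFT path in this change additionally normalizes per sample by answer length.

```python
import torch
import torch.nn.functional as F

def diffusion_loss(logits, input_ids, masked_indices, p_mask, importance_weighting=True):
    """Cross-entropy on masked tokens only, normalized by the total token count."""
    token_loss = F.cross_entropy(
        logits[masked_indices].float(), input_ids[masked_indices], reduction="none"
    )
    if importance_weighting:
        token_loss = token_loss / p_mask[masked_indices].clamp_min(1e-6)
    return token_loss.sum() / input_ids.numel()
```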
-### Model Loading and Saving
-
-The custom models work seamlessly with HuggingFace's AutoModel system:
-
-```python
-from transformers import AutoModel, AutoConfig
-
-# Load a diffusion model
-model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
-
-# Save a diffusion model
-model.save_pretrained("path/to/save/diffusion/model")
-```
-
-During inference, the models behave like standard causal language models.
 
 ## Sample Generation
 
 When `generate_samples: true`, the plugin generates samples during training:
@@ -144,19 +115,9 @@ The plugin adds several metrics to track diffusion training:
 - `train/ce_loss`: Unweighted cross-entropy loss
 - `train/importance_weight_avg`: Average importance weight
 
-## Benefits of Custom Model Approach
-
-**Type Safety**: Full IDE support and type checking
-**HuggingFace Integration**: Works with AutoModel, Hub, pipelines
-**Maintainability**: Clean architecture, no monkey patching
-**Ecosystem Compatibility**: Standard save/load, PEFT support
-**Testing**: Easier to test and debug
-
 ## Limitations
 
-- **Model Support**: Currently limited to Llama and Mistral architectures
-- **Flash Attention**: Not yet optimized for flash attention
-- **Inference Speed**: Bidirectional attention is slower than causal for generation
+- No flash attention support
 
 ## References

View File

@@ -1,26 +1,6 @@
"""Diffusion LM training plugin init.""" """Diffusion LM training plugin init."""
from transformers import AutoConfig, AutoModel
from .args import DiffusionArgs from .args import DiffusionArgs
from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
from .plugin import DiffusionPlugin from .plugin import DiffusionPlugin
# Register custom configurations __all__ = ["DiffusionArgs", "DiffusionPlugin"]
AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
# Register custom models
AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionConfig",
"LlamaForDiffusionConfig",
"MistralForDiffusionConfig",
"LlamaForDiffusionLM",
"MistralForDiffusionLM",
]

View File

@@ -26,31 +26,29 @@ class DiffusionGenerationCallback(TrainerCallback):
         **kwargs,
     ):
         """Generate samples at specified intervals."""
-        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
         if (
             state.global_step > 0
-            and state.global_step % config.get('generation_interval', 100) == 0
+            and state.global_step % self.trainer.config.generation_interval == 0
         ):
             # Use eval dataloader if available, otherwise use train dataloader
             if (
                 hasattr(self.trainer, "eval_dataset")
                 and self.trainer.eval_dataset is not None
             ):
-                dataloader = self.trainer.get_eval_dataloader()
+                dataloader = self.trainer.callback_handler.eval_dataloader
             else:
-                dataloader = self.trainer.get_train_dataloader()
+                dataloader = self.trainer.callback_handler.train_dataloader
 
             # Generate samples
             samples = generate_samples(
                 model=self.trainer.model,
                 tokenizer=self.trainer.tokenizer,
                 dataloader=dataloader,
-                num_generation_samples=config.get('num_generation_samples', 3),
-                max_length=config.get('generation_max_length', 256),
-                num_diffusion_steps=config.get('generation_steps', 10),
-                temperature=config.get('generation_temperature', 1.0),
-                mask_token_id=config.get('mask_token_id', 32000),
+                num_generation_samples=self.trainer.config.num_generation_samples,
+                max_length=self.trainer.config.generation_max_length,
+                num_diffusion_steps=self.trainer.config.generation_steps,
+                temperature=self.trainer.config.generation_temperature,
+                mask_token_id=self.trainer.config.mask_token_id,
             )
 
             # Log samples
@@ -83,8 +81,7 @@ class DiffusionGenerationCallback(TrainerCallback):
         LOG.info("=" * 60)
 
-        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
-        if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
+        if self.trainer.config.use_wandb and self.trainer.state.is_world_process_zero:
             if wandb.run is not None:
                 wandb.log(
                     {

View File

@@ -1,71 +0,0 @@
"""Configuration classes for diffusion language models."""
from transformers import LlamaConfig, MistralConfig
class LlamaForDiffusionConfig(LlamaConfig):
"""Configuration class for Llama models with diffusion training."""
model_type = "llama_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
class MistralForDiffusionConfig(MistralConfig):
"""Configuration class for Mistral models with diffusion training."""
model_type = "mistral_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
# Keep the base class for backward compatibility but mark as deprecated
class DiffusionConfig(LlamaForDiffusionConfig):
"""
Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
"""
model_type = "diffusion"
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,426 +0,0 @@
"""Custom model classes for diffusion language models."""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
class DiffusionModelMixin:
"""Mixin class providing diffusion functionality to language models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._special_token_ids = None
def _cache_special_token_ids(self, tokenizer=None):
"""Cache special token IDs to avoid repeated tokenizer access."""
if tokenizer is None:
self._special_token_ids = set()
return
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
masked_indices: torch.Tensor | None = None,
p_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute diffusion loss given logits and masking information.
Args:
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
logits: Model logits [batch_size, seq_len, vocab_size].
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
"""
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return loss
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
"""
Llama model for diffusion language modeling.
This model extends LlamaForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = LlamaForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
"""
Mistral model for diffusion language modeling.
This model extends MistralForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = MistralForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

View File

@@ -1,20 +1,13 @@
"""Diffusion LM training plugin for Axolotl.""" """Diffusion LM training plugin for Axolotl."""
from typing import TYPE_CHECKING
from peft import PeftModel from peft import PeftModel
from transformers import AutoConfig, AutoModel, PreTrainedModel from transformers import PreTrainedModel
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback from .trainer import DiffusionTrainer
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
if TYPE_CHECKING:
from transformers import Trainer
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -35,64 +28,14 @@ class DiffusionPlugin(BasePlugin):
"""Returns the pydantic model for LLaDA plugin arguments.""" """Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs" return "axolotl.integrations.diffusion.DiffusionArgs"
def pre_model_load(self, cfg: DictDefault):
"""Configure model loading to use diffusion model classes."""
# Map base model types to diffusion equivalents
base_model_type = cfg.get("model_type")
if base_model_type == "llama":
# Create diffusion config from base config
diffusion_config = LlamaForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "llama_diffusion"
elif base_model_type == "mistral":
# Create diffusion config from base config
diffusion_config = MistralForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "mistral_diffusion"
else:
LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Configure model after loading.""" """Perform actions after model is loaded."""
self.cfg = cfg self.cfg = cfg
# Set tokenizer on diffusion models for special token handling
if hasattr(model, "set_tokenizer"):
# Get tokenizer from cfg if available
tokenizer = getattr(cfg, "tokenizer", None)
if tokenizer is not None:
model.set_tokenizer(tokenizer)
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"): def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
"""Add diffusion-specific callbacks after trainer creation.""" """Return custom trainer class for diffusion training."""
callbacks = [] return DiffusionTrainer
# Store diffusion config on trainer for callbacks def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
trainer.diffusion_config = cfg """Configure trainer after creation."""
trainer.set_config(cfg)
# Add generation callback if enabled
if cfg.get("generate_samples", False):
generation_callback = DiffusionGenerationCallback(trainer)
callbacks.append(generation_callback)
return callbacks

View File

@@ -0,0 +1,336 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
from transformers.masking_utils import find_packed_sequence_indices
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.diffusion.utils import create_bidirectional_block_mask
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.config = config
self._cache_special_token_ids()
if config.generate_samples:
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
labels = inputs.get("labels")
position_ids = inputs.get("position_ids")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(
model, input_ids, attention_mask, labels, position_ids
)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
self._special_token_ids = set()
return
tokenizer = self.processing_class
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
@torch.compile
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
min_p: float = 0.0,
max_p: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = min_p + (max_p - min_p) * (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
@torch.compile
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
position_ids: Position ids [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
if position_ids is None:
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
if self.config.flex_attention:
block_mask = create_bidirectional_block_mask(
input_ids, attention_mask, position_ids
)
else:
packed_seq_mask = find_packed_sequence_indices(position_ids)
block_mask = packed_seq_mask.unsqueeze(2) == packed_seq_mask.unsqueeze(1)
return block_mask
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
Args:
model: The model to compute loss for.
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
position_ids: Position ids [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
metrics: Dictionary of metrics.
"""
# Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps, self.config.min_mask_ratio, self.config.max_mask_ratio
)
# Create bidirectional attention mask (optional: use causal if you want strict AR behavior)
bidirectional_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask, position_ids
)
# Forward pass
outputs = model(
input_ids=noisy_batch,
attention_mask=bidirectional_mask,
)
logits = outputs.logits # [B, L, V]
# ----- AR label shift toggle -----
use_ar_shift = False
if use_ar_shift:
# Predict token at t from logits at t-1: drop last logit step, drop first target step
logits_eff = logits[:, :-1, :]
input_ids_eff = input_ids[:, 1:]
masked_indices_eff = masked_indices[:, 1:]
p_mask_eff = p_mask[:, 1:]
labels_eff = labels[:, 1:] if labels is not None else None
else:
logits_eff = logits
input_ids_eff = input_ids
masked_indices_eff = masked_indices
p_mask_eff = p_mask
labels_eff = labels
if masked_indices_eff.sum() > 0:
valid_indices = torch.where(masked_indices_eff)
batch_indices, seq_indices = valid_indices
masked_logits = logits_eff[batch_indices, seq_indices]
masked_targets = input_ids_eff[batch_indices, seq_indices]
masked_p_mask = p_mask_eff[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float().clamp_min(1e-6)
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels_eff is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels_eff != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.any():
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
ce_loss = token_loss.mean()
# Compute accuracy on masked tokens
with torch.no_grad():
pred_tokens = masked_logits.argmax(dim=-1)
accuracy = (pred_tokens == masked_targets).float().mean()
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
accuracy = torch.tensor(0.0, device=input_ids.device)
ce_loss = torch.tensor(0.0, device=input_ids.device)
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
# Keep eff tensors around for metrics
masked_indices_eff = masked_indices
p_mask_eff = p_mask
labels_eff = labels
# Metrics (aligned to the effective tensors)
if masked_indices_eff.any():
avg_p = p_mask_eff[masked_indices_eff].float().mean().item()
num_masked = int(masked_indices_eff.sum().item())
mask_ratio = masked_indices_eff.float().mean().item()
else:
avg_p = 0.0
num_masked = 0
mask_ratio = 0.0
metrics = {
"loss": float(loss.detach()),
"accuracy": float(accuracy.detach()),
"mask_ratio": mask_ratio,
"num_masked_tokens": (num_masked, "sum"),
"avg_p_mask": avg_p,
"ce_loss": float(ce_loss.detach()),
}
# SFT-specific metrics (aligned)
if labels_eff is not None:
answer_mask = labels_eff != -100
metrics["answer_ratio"] = answer_mask.float().mean().item()
metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item()
if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
return loss, outputs

View File

@@ -0,0 +1,50 @@
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from transformers.masking_utils import find_packed_sequence_indices, packed_sequence_mask_function
def create_bidirectional_block_mask(
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> "BlockMask":
"""
Creates a bidirectional block mask for FlexAttention.
Args:
input_ids: Input token ids [batch_size, seq_len]
attention_mask: Padding mask [batch_size, seq_len]
Returns:
BlockMask for bidirectional attention with padding
"""
batch_size, seq_len = input_ids.shape
if position_ids is not None:
packed_seq_mask = find_packed_sequence_indices(position_ids)
mask_fn = packed_sequence_mask_function(packed_seq_mask, batch_size, seq_len)
elif attention_mask is None:
# If no padding mask, all positions can attend to all positions
def mask_fn(b, h, q_idx, kv_idx):
# Always return True for bidirectional attention
return True
else:
# Convert attention_mask to boolean if needed
attention_mask = attention_mask.bool()
def mask_fn(b, h, q_idx, kv_idx):
# Both query and key positions must be valid (not padding)
return attention_mask[b, q_idx] & attention_mask[b, kv_idx]
# Create the block mask
block_mask = create_block_mask(
mask_fn,
B=batch_size,
H=None, # Will be set by the attention layer
Q_LEN=seq_len,
KV_LEN=seq_len,
device=input_ids.device,
_compile=True,
)
return block_mask

View File

@@ -57,7 +57,7 @@ class SpectrumPlugin(BasePlugin):
     Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
     """
 
-    base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/"
+    base_url = "https://raw.githubusercontent.com/QuixiAI/spectrum/main/model_snr_results/"
     base_path = "./model_snr_results/"
     snr_file_template = "snr_results_{model_name_slug}.json"

View File

@@ -16,7 +16,7 @@ from packaging.version import Version, parse
 def check_cuda_p2p_ib_support():
     if not accelerate_check_cuda_p2p_ib_support():
         return False
-    unsupported_devices = {"RTX 6000 Ada", "L40S"}
+    unsupported_devices = {"RTX 6000 Ada", "L40S", "A40"}
     try:
         device_names, device_count = get_gpu_info()
         if 1 < device_count < 8:

View File

@@ -1,14 +1,13 @@
"""Tests for diffusion model integration.""" """Tests for diffusion trainer integration."""
# pylint: disable=redefined-outer-name,protected-access # pylint: disable=redefined-outer-name,protected-access
from unittest.mock import Mock, patch from unittest.mock import Mock
import pytest import pytest
import torch import torch
from axolotl.integrations.diffusion.configuration import LlamaForDiffusionConfig from axolotl.integrations.diffusion.trainer import DiffusionTrainer
from axolotl.integrations.diffusion.models import LlamaForDiffusionLM
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -25,44 +24,37 @@ def mock_tokenizer():
 @pytest.fixture
 def diffusion_config():
     """Create a diffusion config."""
-    return LlamaForDiffusionConfig(
-        mask_token_id=32000,
-        eps=1e-3,
-        importance_weighting=False,
-        sample_packing=False,
-        # Basic llama config fields - smaller for testing
-        vocab_size=1000,
-        hidden_size=256,
-        intermediate_size=512,
-        num_hidden_layers=2,
-        num_attention_heads=4,
+    return DictDefault(
+        {
+            "mask_token_id": 32000,
+            "eps": 1e-3,
+            "importance_weighting": False,
+            "sample_packing": False,
+        }
     )
 
 
 @pytest.fixture
-def diffusion_model_instance(mock_tokenizer, diffusion_config):
-    """Create a diffusion model instance for testing methods directly."""
-    # Create a minimal model instance for testing
-    model = object.__new__(LlamaForDiffusionLM)
-    model.config = diffusion_config
-    model._special_token_ids = {0, 1, 2}  # pad, bos, eos
-    model.training = True
-
-    # Set tokenizer
-    model.set_tokenizer(mock_tokenizer)
-    return model
+def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
+    """Create a diffusion trainer instance for testing methods directly."""
+    # Create a minimal trainer instance just for testing methods
+    trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
+    trainer.config = diffusion_config
+    trainer._special_token_ids = {0, 1, 2}  # pad, bos, eos
+    trainer.processing_class = mock_tokenizer
+    trainer.store_metrics = Mock()  # Mock metrics storage
+    return trainer
 
 
-class TestDiffusionModel:
-    """Test the DiffusionModel class."""
+class TestDiffusionTrainer:
+    """Test the DiffusionTrainer class."""
 
-    def test_forward_process_basic(self, diffusion_model_instance):
+    def test_forward_process_basic(self, diffusion_trainer_instance):
         """Test basic forward process without labels."""
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
 
         noisy_batch, masked_indices, p_mask = (
-            diffusion_model_instance._forward_process(input_ids, eps=0.1)
+            diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
         )
 
         # Check shapes
@@ -75,18 +67,18 @@ class TestDiffusionModel:
         assert not masked_indices[special_token_positions].any()
 
         # Check that mask token is applied
-        mask_token_id = diffusion_model_instance.config.mask_token_id
+        mask_token_id = diffusion_trainer_instance.config.mask_token_id
         masked_positions = masked_indices
         if masked_positions.any():
             assert (noisy_batch[masked_positions] == mask_token_id).all()
 
-    def test_forward_process_with_labels(self, diffusion_model_instance):
+    def test_forward_process_with_labels(self, diffusion_trainer_instance):
         """Test forward process with SFT labels."""
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
         labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
 
         noisy_batch, masked_indices, p_mask = (
-            diffusion_model_instance._forward_process(
+            diffusion_trainer_instance._forward_process(
                 input_ids, labels=labels, eps=0.1
             )
         )
@@ -108,12 +100,12 @@ class TestDiffusionModel:
         # Verify that masked_indices respects the answer mask
         assert not masked_indices[non_answer_mask].any()
 
-    def test_forward_process_with_attention_mask(self, diffusion_model_instance):
+    def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
         """Test forward process with attention mask."""
         input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
         attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
 
-        _, masked_indices, p_mask = diffusion_model_instance._forward_process(
+        _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
             input_ids, attention_mask=attention_mask, eps=0.1
         )
@@ -122,11 +114,11 @@ class TestDiffusionModel:
         assert not masked_indices[padding_positions].any()
         assert (p_mask[padding_positions] == 0).all()
 
-    def test_bidirectional_attention_mask_no_packing(self, diffusion_model_instance):
+    def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
         """Test bidirectional attention mask without sample packing."""
         input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
 
-        mask = diffusion_model_instance._create_bidirectional_attention_mask(
+        mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
             input_ids
         )
@@ -136,15 +128,15 @@ class TestDiffusionModel:
         assert mask.all()
 
     def test_bidirectional_attention_mask_with_packing(
-        self, diffusion_model_instance
+        self, diffusion_trainer_instance
     ):
         """Test bidirectional attention mask with sample packing."""
-        diffusion_model_instance.config.sample_packing = True
+        diffusion_trainer_instance.config.sample_packing = True
 
         input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
         # Sample IDs: first sample (1), second sample (2)
         attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
 
-        mask = diffusion_model_instance._create_bidirectional_attention_mask(
+        mask = diffusion_trainer_instance._create_bidirectional_attention_mask(
             input_ids, attention_mask
         )
@@ -156,135 +148,124 @@ class TestDiffusionModel:
         assert not mask[0, 0, 2, 4].item()
         assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other
 
-    def test_compute_loss_basic(self, diffusion_model_instance):
+    def test_compute_loss_basic(self, diffusion_trainer_instance):
         """Test basic loss computation."""
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
-
-        # Create mock data for loss computation
-        vocab_size = 1000
-        seq_len = 5
-        logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
-
-        # Create a simple masked indices tensor (mask middle tokens)
-        masked_indices = torch.tensor([[False, True, True, False, False]], dtype=torch.bool)
-        p_mask = torch.tensor([[0.1, 0.5, 0.5, 0.1, 0.1]], dtype=torch.float)
-
-        loss = diffusion_model_instance._compute_diffusion_loss(
-            input_ids=input_ids,
-            logits=logits,
-            masked_indices=masked_indices,
-            p_mask=p_mask,
-        )
+        # Mock model that returns logits
+        mock_model = Mock()
+        mock_outputs = Mock()
+        vocab_size = 1000
+        seq_len = 5
+        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
+        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
+
+        loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
+            mock_model, input_ids
+        )
 
         # Check that loss is computed
         assert isinstance(loss, torch.Tensor)
         assert loss.requires_grad
+        assert outputs == mock_outputs
+
+        # Check that metrics were stored
+        diffusion_trainer_instance.store_metrics.assert_called_once()
 
-    def test_compute_loss_with_labels(self, diffusion_model_instance):
+    def test_compute_loss_with_labels(self, diffusion_trainer_instance):
         """Test loss computation with SFT labels."""
+        # Mock model
+        mock_model = Mock()
+        mock_outputs = Mock()
+        vocab_size = 1000
+        seq_len = 5
+        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
         input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
         labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
 
-        # Create mock data for loss computation
-        vocab_size = 1000
-        seq_len = 5
-        logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
-
-        # Create masked indices that only covers answer tokens
-        masked_indices = torch.tensor([[False, False, True, True, False]], dtype=torch.bool)
-        p_mask = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.1]], dtype=torch.float)
-
-        loss = diffusion_model_instance._compute_diffusion_loss(
-            input_ids=input_ids,
-            labels=labels,
-            logits=logits,
-            masked_indices=masked_indices,
-            p_mask=p_mask,
-        )
+        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
+            mock_model, input_ids, labels=labels
+        )
 
         # Check that loss is computed
         assert isinstance(loss, torch.Tensor)
         assert loss.requires_grad
 
-    def test_compute_loss_no_masked_tokens(self, diffusion_model_instance):
+        # Check that SFT metrics were added
+        call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
+        assert "answer_ratio" in call_args
+        assert "avg_answer_length" in call_args
+
+    def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
         """Test loss computation when no tokens are masked."""
-        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
-
-        # Create mock data for loss computation
-        vocab_size = 1000
-        seq_len = 3
-        logits = torch.randn(1, seq_len, vocab_size)
-
-        # No tokens masked
-        masked_indices = torch.tensor([[False, False, False]], dtype=torch.bool)
-        p_mask = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float)
-
-        loss = diffusion_model_instance._compute_diffusion_loss(
-            input_ids=input_ids,
-            logits=logits,
-            masked_indices=masked_indices,
-            p_mask=p_mask,
-        )
+        # Mock model
+        mock_model = Mock()
+        mock_outputs = Mock()
+        vocab_size = 1000
+        seq_len = 3
+        mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
+        # Only special tokens (which won't be masked)
+        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
+
+        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
+            mock_model, input_ids
+        )
 
         # Loss should be zero when no tokens are masked
         assert loss.item() == 0.0
         assert loss.requires_grad
 
-    def test_cache_special_token_ids(self, diffusion_model_instance):
+    def test_cache_special_token_ids(self, diffusion_trainer_instance):
         """Test caching of special token IDs."""
         # Should cache BOS, EOS, PAD tokens
         expected_tokens = {0, 1, 2}  # pad, bos, eos
-        assert diffusion_model_instance._special_token_ids == expected_tokens
+        assert diffusion_trainer_instance._special_token_ids == expected_tokens
 
     def test_cache_special_token_ids_no_tokenizer(self):
         """Test caching when no tokenizer is available."""
-        # Mock the parent model initialization to avoid loading pretrained weights
-        with patch('transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__'):
-            model = LlamaForDiffusionLM.__new__(LlamaForDiffusionLM)
-            model._cache_special_token_ids(None)
-
-            assert model._special_token_ids == set()
-
-    def test_forward_training_mode(self, diffusion_model_instance):
-        """Test forward pass in training mode."""
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
-        attention_mask = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.bool)
-
-        # Mock the parent forward method
-        with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
-            mock_output = Mock()
-            mock_output.logits = torch.randn(1, 5, 32000)
-            mock_forward.return_value = mock_output
-
-            # Set training mode
-            diffusion_model_instance.training = True
-
-            result = diffusion_model_instance.forward(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                return_dict=True
-            )
-
-            # Should call parent forward and compute loss
-            assert mock_forward.called
-            assert hasattr(result, 'loss')
-
-    def test_forward_inference_mode(self, diffusion_model_instance):
-        """Test forward pass in inference mode."""
-        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
-
-        # Mock the parent forward method
-        with patch.object(diffusion_model_instance.__class__.__bases__[1], 'forward') as mock_forward:
-            mock_output = Mock()
-            mock_forward.return_value = mock_output
-
-            # Set inference mode
-            diffusion_model_instance.training = False
-
-            result = diffusion_model_instance.forward(
-                input_ids=input_ids,
-                return_dict=True
-            )
-
-            # Should just call parent forward without diffusion processing
-            assert mock_forward.called
-            assert result == mock_output
+        trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
+        trainer.processing_class = None
+        trainer._cache_special_token_ids()
+
+        assert trainer._special_token_ids == set()
+
+    def test_main_compute_loss_interface(self, diffusion_trainer_instance):
+        """Test the main compute_loss interface."""
+        # Mock model
+        mock_model = Mock()
+        mock_outputs = Mock()
+        mock_outputs.logits = torch.randn(1, 5, 1000)
+        mock_model.return_value = mock_outputs
+        mock_model.training = True
+
+        inputs = {
+            "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
+            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
+            "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
+        }
+
+        # Test without return_outputs
+        loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
+        assert isinstance(loss, torch.Tensor)
+
+        # Test with return_outputs
+        loss, outputs = diffusion_trainer_instance.compute_loss(
+            mock_model, inputs, return_outputs=True
+        )
+        assert isinstance(loss, torch.Tensor)
+        assert outputs == mock_outputs
+
+    def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
+        """Test that missing input_ids raises ValueError."""
+        mock_model = Mock()
+        inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}
+
+        with pytest.raises(ValueError, match="input_ids is required"):
+            diffusion_trainer_instance.compute_loss(mock_model, inputs)