This commit is contained in:
Dan Saunders
2025-08-18 19:09:09 +00:00
parent 556a69118f
commit b210db2d15
7 changed files with 33 additions and 44 deletions

View File

@@ -2,29 +2,27 @@ base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
# Dataset configuration for pretraining pretraining_dataset:
datasets:
- path: wikitext - path: wikitext
name: wikitext-103-raw-v1 name: wikitext-103-raw-v1
type: completion type: completion
field: text field: text
val_set_size: 0.001
plugins: plugins:
- diffusion.DiffusionPlugin - diffusion.DiffusionPlugin
noise_schedule: "cosine" noise_schedule: cosine
min_mask_ratio: 0.15 min_mask_ratio: 0.15
max_mask_ratio: 0.85 max_mask_ratio: 0.85
num_diffusion_steps: 128
eps: 5e-4 eps: 5e-4
importance_weighting: true importance_weighting: true
mask_token_id: 128002 mask_token_id: 128002
generate_samples: true
generation_interval: 10
output_dir: ./outputs/model-out output_dir: ./outputs/model-out
sequence_len: 512 sequence_len: 512
sample_packing: false sample_packing: true
eval_sample_packing: false
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
micro_batch_size: 4 micro_batch_size: 4
@@ -42,12 +40,10 @@ resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
sdp_attention: true sdp_attention: true
warmup_steps: 500 warmup_steps: 1000
save_strategy: steps save_strategy: steps
eval_strategy: steps
save_steps: 1000 save_steps: 1000
eval_steps: 1000
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"

View File

@@ -9,7 +9,7 @@ val_set_size: 0.05
plugins: plugins:
- diffusion.DiffusionPlugin - diffusion.DiffusionPlugin
noise_schedule: "linear" noise_schedule: cosine
min_mask_ratio: 0.1 min_mask_ratio: 0.1
max_mask_ratio: 0.9 max_mask_ratio: 0.9
num_diffusion_steps: 128 num_diffusion_steps: 128
@@ -39,6 +39,8 @@ resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
sdp_attention: true sdp_attention: true
warmup_steps: 1000
save_strategy: steps save_strategy: steps
eval_strategy: steps eval_strategy: steps
save_steps: 500 save_steps: 500

View File

@@ -27,8 +27,6 @@ class DiffusionArgs(BaseModel):
num_diffusion_steps: int = Field( num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps" default=128, ge=1, description="Number of diffusion timesteps"
) )
# Forward process config
eps: float = Field( eps: float = Field(
default=1e-3, default=1e-3,
ge=0.0, ge=0.0,

View File

@@ -26,26 +26,24 @@ class DiffusionGenerationCallback(TrainerCallback):
**kwargs, **kwargs,
): ):
"""Generate samples at specified intervals.""" """Generate samples at specified intervals."""
# Only generate samples at the specified interval and after step 0
if ( if (
state.global_step > 0 state.global_step > 0
and state.global_step % self.trainer.config.generation_interval == 0 and state.global_step % self.trainer.config.generation_interval == 0
and hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
): ):
# Use eval dataloader if available, otherwise use train dataloader
LOG.info( if (
f"Generating {self.trainer.config.num_generation_samples} samples at step {state.global_step}..." hasattr(self.trainer, "eval_dataset")
) and self.trainer.eval_dataset is not None
):
# Create a simple dataloader from eval dataset for sampling dataloader = self.trainer.callback_handler.eval_dataloader
eval_dataloader = self.trainer.get_eval_dataloader() else:
dataloader = self.trainer.callback_handler.train_dataloader
# Generate samples # Generate samples
samples = generate_samples( samples = generate_samples(
model=self.trainer.model, model=self.trainer.model,
tokenizer=self.trainer.tokenizer, tokenizer=self.trainer.tokenizer,
val_dataloader=eval_dataloader, dataloader=dataloader,
num_generation_samples=self.trainer.config.num_generation_samples, num_generation_samples=self.trainer.config.num_generation_samples,
max_length=self.trainer.config.generation_max_length, max_length=self.trainer.config.generation_max_length,
num_diffusion_steps=self.trainer.config.generation_steps, num_diffusion_steps=self.trainer.config.generation_steps,

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
def generate_samples( def generate_samples(
model: torch.nn.Module, model: torch.nn.Module,
tokenizer: Any, tokenizer: Any,
val_dataloader: Optional[Any] = None, dataloader: Optional[Any] = None,
num_generation_samples: int = 3, num_generation_samples: int = 3,
max_length: int = 100, max_length: int = 100,
num_diffusion_steps: int = 128, num_diffusion_steps: int = 128,
@@ -19,13 +19,13 @@ def generate_samples(
mask_token_id: int = 32000, mask_token_id: int = 32000,
) -> List[dict]: ) -> List[dict]:
""" """
Generate text samples using the diffusion model by randomly masking sequences Generate text samples using the diffusion model by randomly masking sequences from
from the validation dataset and running the reverse diffusion process. the given dataset and running the reverse diffusion process.
Args: Args:
model: The wrapped or unwrapped model model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding tokenizer: Tokenizer for encoding/decoding
val_dataloader: Validation dataloader (for sampling sequences) dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation num_diffusion_steps: Number of diffusion steps for generation
@@ -35,7 +35,7 @@ def generate_samples(
Returns: Returns:
List of dictionaries with original text, masked text, and generated text List of dictionaries with original text, masked text, and generated text
""" """
if val_dataloader is None: if dataloader is None:
logger.warning("No validation dataloader provided, cannot generate samples") logger.warning("No validation dataloader provided, cannot generate samples")
return [] return []
@@ -46,7 +46,7 @@ def generate_samples(
# Sample sequences from validation dataset # Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader( sampled_sequences = _sample_sequences_from_dataloader(
val_dataloader, num_generation_samples, max_length, unwrapped_model.device dataloader, num_generation_samples, max_length, unwrapped_model.device
) )
logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset") logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
@@ -68,7 +68,7 @@ def generate_samples(
def _sample_sequences_from_dataloader( def _sample_sequences_from_dataloader(
val_dataloader: Any, num_samples: int, max_length: int, device: torch.device dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
"""Sample sequences from validation dataloader.""" """Sample sequences from validation dataloader."""
sampled_sequences = [] sampled_sequences = []
@@ -78,7 +78,7 @@ def _sample_sequences_from_dataloader(
skip_batches = torch.randint(0, 6, (1,)).item() skip_batches = torch.randint(0, 6, (1,)).item()
batch_count = 0 batch_count = 0
for batch in val_dataloader: for batch in dataloader:
# Skip some batches for variety # Skip some batches for variety
if batch_count < skip_batches: if batch_count < skip_batches:
batch_count += 1 batch_count += 1
@@ -183,13 +183,15 @@ def _generate(
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str: def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display.""" """Clean up masked text for display."""
# Get the mask token representation from the tokenizer
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False) mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]") cleaned = masked_text.replace(mask_token_repr, "[MASK]")
# Clean up special tokens and whitespace if hasattr(tokenizer, "special_tokens_map"):
cleaned = cleaned.replace("<s>", "").replace("</s>", "").strip() for token_value in tokenizer.special_tokens_map.values():
cleaned = " ".join(cleaned.split()) if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
cleaned = " ".join(cleaned.split()).strip()
return cleaned return cleaned

View File

@@ -270,13 +270,6 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"avg_p_mask": p_mask[masked_indices].mean().item(), "avg_p_mask": p_mask[masked_indices].mean().item(),
"ce_loss": ce_loss.item(), "ce_loss": ce_loss.item(),
} }
# Add SFT-specific metrics
if labels is not None:
answer_mask = labels != -100
metrics["answer_ratio"] = answer_mask.float().mean().item()
metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item()
if self.config.importance_weighting: if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()

View File

@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding: ) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt: if not prompt:
LOG.warning("Empty text requested for tokenization.") LOG.warning_once("Empty text requested for tokenization.")
return empty return empty
result = self.tokenizer( result = self.tokenizer(