diff --git a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
index 1e17a0ea1..8a7fe3f9c 100644
--- a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
+++ b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
@@ -2,29 +2,27 @@ base_model: meta-llama/Llama-3.2-1B
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
-# Dataset configuration for pretraining
-datasets:
+pretraining_dataset:
   - path: wikitext
     name: wikitext-103-raw-v1
     type: completion
     field: text
-val_set_size: 0.001
 
 plugins:
   - diffusion.DiffusionPlugin
 
-noise_schedule: "cosine"
+noise_schedule: cosine
 min_mask_ratio: 0.15
 max_mask_ratio: 0.85
-num_diffusion_steps: 128
 eps: 5e-4
 importance_weighting: true
 mask_token_id: 128002
+generate_samples: true
+generation_interval: 10
 
 output_dir: ./outputs/model-out
 
 sequence_len: 512
-sample_packing: false
-eval_sample_packing: false
+sample_packing: true
 
 gradient_accumulation_steps: 8
 micro_batch_size: 4
@@ -42,12 +40,10 @@ resume_from_checkpoint:
 logging_steps: 1
 sdp_attention: true
 
-warmup_steps: 500
+warmup_steps: 1000
 
 save_strategy: steps
-eval_strategy: steps
 save_steps: 1000
-eval_steps: 1000
 
 special_tokens:
   pad_token: "<|end_of_text|>"
diff --git a/examples/llama-3/diffusion-3.2-1b-sft.yaml b/examples/llama-3/diffusion-3.2-1b-sft.yaml
index af00ac9fd..1f5fa263a 100644
--- a/examples/llama-3/diffusion-3.2-1b-sft.yaml
+++ b/examples/llama-3/diffusion-3.2-1b-sft.yaml
@@ -9,7 +9,7 @@ val_set_size: 0.05
 plugins:
   - diffusion.DiffusionPlugin
 
-noise_schedule: "linear"
+noise_schedule: cosine
 min_mask_ratio: 0.1
 max_mask_ratio: 0.9
 num_diffusion_steps: 128
@@ -39,6 +39,8 @@ resume_from_checkpoint:
 logging_steps: 1
 sdp_attention: true
 
+warmup_steps: 1000
+
 save_strategy: steps
 eval_strategy: steps
 save_steps: 500
diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py
index 1f6263e47..7d49c51b0 100644
--- a/src/axolotl/integrations/diffusion/args.py
+++ b/src/axolotl/integrations/diffusion/args.py
@@ -27,8 +27,6 @@ class DiffusionArgs(BaseModel):
     num_diffusion_steps: int = Field(
         default=128, ge=1, description="Number of diffusion timesteps"
     )
-
-    # Forward process config
    eps: float = Field(
        default=1e-3,
        ge=0.0,
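For orientation between the config and code changes: min_mask_ratio, max_mask_ratio, eps, and noise_schedule above parameterize the forward noising step that the trainer applies to each batch. A minimal sketch of that step, assuming a cosine schedule interpolating between the two ratios; forward_mask is a hypothetical helper for illustration, not the plugin's actual API:

    import torch

    def forward_mask(
        input_ids: torch.Tensor,        # (batch, seq_len) token ids
        mask_token_id: int = 128002,
        min_mask_ratio: float = 0.15,
        max_mask_ratio: float = 0.85,
        eps: float = 5e-4,
        noise_schedule: str = "cosine",
    ):
        """Sample a timestep per sequence and mask tokens at the implied ratio."""
        t = torch.rand(input_ids.shape[0])  # continuous timestep in [0, 1)
        if noise_schedule == "cosine":
            level = 1.0 - torch.cos(t * torch.pi / 2)  # slow start, fast finish
        else:  # linear
            level = t
        # Interpolate the per-sequence mask probability between the two ratios;
        # eps keeps it strictly positive so 1 / p_mask importance weights stay finite.
        p_mask = (min_mask_ratio + (max_mask_ratio - min_mask_ratio) * level).clamp(min=eps)
        masked = torch.rand_like(input_ids, dtype=torch.float) < p_mask[:, None]
        noisy_ids = torch.where(masked, mask_token_id, input_ids)
        return noisy_ids, masked, p_mask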
diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py
index fafd99919..595bf49a1 100644
--- a/src/axolotl/integrations/diffusion/callbacks.py
+++ b/src/axolotl/integrations/diffusion/callbacks.py
@@ -26,26 +26,24 @@ class DiffusionGenerationCallback(TrainerCallback):
         **kwargs,
     ):
         """Generate samples at specified intervals."""
-        # Only generate samples at the specified interval and after step 0
         if (
             state.global_step > 0
             and state.global_step % self.trainer.config.generation_interval == 0
-            and hasattr(self.trainer, "eval_dataset")
-            and self.trainer.eval_dataset is not None
         ):
-
-            LOG.info(
-                f"Generating {self.trainer.config.num_generation_samples} samples at step {state.global_step}..."
-            )
-
-            # Create a simple dataloader from eval dataset for sampling
-            eval_dataloader = self.trainer.get_eval_dataloader()
+            # Use eval dataloader if available, otherwise use train dataloader
+            if (
+                hasattr(self.trainer, "eval_dataset")
+                and self.trainer.eval_dataset is not None
+            ):
+                dataloader = self.trainer.callback_handler.eval_dataloader
+            else:
+                dataloader = self.trainer.callback_handler.train_dataloader
 
             # Generate samples
             samples = generate_samples(
                 model=self.trainer.model,
                 tokenizer=self.trainer.tokenizer,
-                val_dataloader=eval_dataloader,
+                dataloader=dataloader,
                 num_generation_samples=self.trainer.config.num_generation_samples,
                 max_length=self.trainer.config.generation_max_length,
                 num_diffusion_steps=self.trainer.config.generation_steps,
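The dataloader fallback above is the behavioral change worth calling out: sample generation no longer requires an eval split. A sketch of the selection logic in isolation, with SimpleNamespace standing in for the live trainer (attribute names mirror the diff; everything else is illustrative):

    from types import SimpleNamespace

    def pick_dataloader(trainer):
        """Prefer the eval dataloader; fall back to train when no eval split exists."""
        if getattr(trainer, "eval_dataset", None) is not None:
            return trainer.callback_handler.eval_dataloader
        return trainer.callback_handler.train_dataloader

    # Trainer configured without a validation split: the train dataloader is used.
    trainer = SimpleNamespace(
        eval_dataset=None,
        callback_handler=SimpleNamespace(eval_dataloader=None, train_dataloader="train_dl"),
    )
    assert pick_dataloader(trainer) == "train_dl"

One caveat to verify: stock transformers only populates callback_handler.eval_dataloader once an evaluation pass has run, so a generation step that fires before the first eval may still see None there even when an eval dataset exists.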
diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py
index d019ec023..2b07121a2 100644
--- a/src/axolotl/integrations/diffusion/generation.py
+++ b/src/axolotl/integrations/diffusion/generation.py
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 def generate_samples(
     model: torch.nn.Module,
     tokenizer: Any,
-    val_dataloader: Optional[Any] = None,
+    dataloader: Optional[Any] = None,
     num_generation_samples: int = 3,
     max_length: int = 100,
     num_diffusion_steps: int = 128,
@@ -19,13 +19,13 @@ def generate_samples(
     mask_token_id: int = 32000,
 ) -> List[dict]:
     """
-    Generate text samples using the diffusion model by randomly masking sequences
-    from the validation dataset and running the reverse diffusion process.
+    Generate text samples using the diffusion model by randomly masking sequences from
+    the given dataset and running the reverse diffusion process.
 
     Args:
         model: The wrapped or unwrapped model
         tokenizer: Tokenizer for encoding/decoding
-        val_dataloader: Validation dataloader (for sampling sequences)
+        dataloader: Dataloader to sample sequences from (eval or train)
         num_generation_samples: Number of samples to generate
         max_length: Maximum length of sequences to use
         num_diffusion_steps: Number of diffusion steps for generation
@@ -35,7 +35,7 @@
     Returns:
         List of dictionaries with original text, masked text, and generated text
     """
-    if val_dataloader is None:
+    if dataloader is None:
         logger.warning("No validation dataloader provided, cannot generate samples")
         return []
 
@@ -46,7 +46,7 @@
     # Sample sequences from validation dataset
     sampled_sequences = _sample_sequences_from_dataloader(
-        val_dataloader, num_generation_samples, max_length, unwrapped_model.device
+        dataloader, num_generation_samples, max_length, unwrapped_model.device
     )
 
     logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
@@ -68,7 +68,7 @@
 
 def _sample_sequences_from_dataloader(
-    val_dataloader: Any, num_samples: int, max_length: int, device: torch.device
+    dataloader: Any, num_samples: int, max_length: int, device: torch.device
 ) -> List[torch.Tensor]:
     """Sample sequences from validation dataloader."""
     sampled_sequences = []
@@ -78,7 +78,7 @@
     skip_batches = torch.randint(0, 6, (1,)).item()
     batch_count = 0
 
-    for batch in val_dataloader:
+    for batch in dataloader:
         # Skip some batches for variety
         if batch_count < skip_batches:
             batch_count += 1
@@ -183,13 +183,15 @@ def _generate(
 
 
 def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
     """Clean up masked text for display."""
-    # Get the mask token representation from the tokenizer
     mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
     cleaned = masked_text.replace(mask_token_repr, "[MASK]")
 
-    # Clean up special tokens and whitespace
-    cleaned = cleaned.replace("<s>", "").replace("</s>", "").strip()
-    cleaned = " ".join(cleaned.split())
+    if hasattr(tokenizer, "special_tokens_map"):
+        for token_value in tokenizer.special_tokens_map.values():
+            if token_value and isinstance(token_value, str):
+                cleaned = cleaned.replace(token_value, "")
+
+    cleaned = " ".join(cleaned.split()).strip()
 
     return cleaned
diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py
index be1bb9838..dc62035d5 100644
--- a/src/axolotl/integrations/diffusion/trainer.py
+++ b/src/axolotl/integrations/diffusion/trainer.py
@@ -270,13 +270,6 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
             "avg_p_mask": p_mask[masked_indices].mean().item(),
             "ce_loss": ce_loss.item(),
         }
-
-        # Add SFT-specific metrics
-        if labels is not None:
-            answer_mask = labels != -100
-            metrics["answer_ratio"] = answer_mask.float().mean().item()
-            metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item()
-
         if self.config.importance_weighting:
             metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
 
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 9ca645de3..cb0d32d4a 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
     ) -> BatchEncoding:
         empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
         if not prompt:
-            LOG.warning("Empty text requested for tokenization.")
+            LOG.warning_once("Empty text requested for tokenization.")
             return empty
         result = self.tokenizer(
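For readers following the generation path end to end, the reverse process that generate_samples drives looks roughly like the following. This is an illustrative sketch of confidence-based iterative unmasking, assuming an HF-style model that returns .logits; unmask is a hypothetical helper, not the integration's exact algorithm:

    import torch

    @torch.no_grad()
    def unmask(model, input_ids, mask_token_id, num_steps=128):
        """Iteratively replace mask tokens with the model's most confident predictions."""
        ids = input_ids.clone()
        for _ in range(num_steps):
            masked = ids == mask_token_id
            if not masked.any():
                break  # nothing left to fill in
            logits = model(input_ids=ids).logits            # (batch, seq, vocab)
            conf, pred = torch.softmax(logits, -1).max(-1)  # per-position confidence
            conf = conf.masked_fill(~masked, -1.0)          # only masked slots compete
            best = conf.argmax(dim=-1)                      # most confident slot per row
            for row, col in enumerate(best.tolist()):
                if masked[row, col]:                        # skip rows with no masks left
                    ids[row, col] = pred[row, col]
        return ids

As written this reveals at most one token per row per step over num_steps iterations; schedulers that reveal a fraction of the masked positions per step are a common variant of the same idea.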