This commit is contained in:
Dan Saunders
2025-08-18 19:09:09 +00:00
parent 556a69118f
commit b210db2d15
7 changed files with 33 additions and 44 deletions

View File

@@ -2,29 +2,27 @@ base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Dataset configuration for pretraining
datasets:
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
val_set_size: 0.001
plugins:
- diffusion.DiffusionPlugin
noise_schedule: "cosine"
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
num_diffusion_steps: 128
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 10
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: false
eval_sample_packing: false
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
@@ -42,12 +40,10 @@ resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 500
warmup_steps: 1000
save_strategy: steps
eval_strategy: steps
save_steps: 1000
eval_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -9,7 +9,7 @@ val_set_size: 0.05
plugins:
- diffusion.DiffusionPlugin
noise_schedule: "linear"
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
@@ -39,6 +39,8 @@ resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
eval_strategy: steps
save_steps: 500

View File

@@ -27,8 +27,6 @@ class DiffusionArgs(BaseModel):
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
# Forward process config
eps: float = Field(
default=1e-3,
ge=0.0,

View File

@@ -26,26 +26,24 @@ class DiffusionGenerationCallback(TrainerCallback):
**kwargs,
):
"""Generate samples at specified intervals."""
# Only generate samples at the specified interval and after step 0
if (
state.global_step > 0
and state.global_step % self.trainer.config.generation_interval == 0
and hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
):
LOG.info(
f"Generating {self.trainer.config.num_generation_samples} samples at step {state.global_step}..."
)
# Create a simple dataloader from eval dataset for sampling
eval_dataloader = self.trainer.get_eval_dataloader()
# Use eval dataloader if available, otherwise use train dataloader
if (
hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
):
dataloader = self.trainer.callback_handler.eval_dataloader
else:
dataloader = self.trainer.callback_handler.train_dataloader
# Generate samples
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.tokenizer,
val_dataloader=eval_dataloader,
dataloader=dataloader,
num_generation_samples=self.trainer.config.num_generation_samples,
max_length=self.trainer.config.generation_max_length,
num_diffusion_steps=self.trainer.config.generation_steps,

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
val_dataloader: Optional[Any] = None,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
@@ -19,13 +19,13 @@ def generate_samples(
mask_token_id: int = 32000,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences
from the validation dataset and running the reverse diffusion process.
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
val_dataloader: Validation dataloader (for sampling sequences)
dataloader: Dataloader to sample sequences from (eval if available, otherwise train)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
@@ -35,7 +35,7 @@ def generate_samples(
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if val_dataloader is None:
if dataloader is None:
logger.warning("No validation dataloader provided, cannot generate samples")
return []
@@ -46,7 +46,7 @@ def generate_samples(
# Sample sequences from the provided dataloader
sampled_sequences = _sample_sequences_from_dataloader(
val_dataloader, num_generation_samples, max_length, unwrapped_model.device
dataloader, num_generation_samples, max_length, unwrapped_model.device
)
logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
@@ -68,7 +68,7 @@ def generate_samples(
def _sample_sequences_from_dataloader(
val_dataloader: Any, num_samples: int, max_length: int, device: torch.device
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[torch.Tensor]:
"""Sample sequences from the given dataloader."""
sampled_sequences = []
@@ -78,7 +78,7 @@ def _sample_sequences_from_dataloader(
skip_batches = torch.randint(0, 6, (1,)).item()
batch_count = 0
for batch in val_dataloader:
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
@@ -183,13 +183,15 @@ def _generate(
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
# Get the mask token representation from the tokenizer
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
# Clean up special tokens and whitespace
cleaned = cleaned.replace("<s>", "").replace("</s>", "").strip()
cleaned = " ".join(cleaned.split())
if hasattr(tokenizer, "special_tokens_map"):
for token_value in tokenizer.special_tokens_map.values():
if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
cleaned = " ".join(cleaned.split()).strip()
return cleaned

View File

@@ -270,13 +270,6 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"avg_p_mask": p_mask[masked_indices].mean().item(),
"ce_loss": ce_loss.item(),
}
# Add SFT-specific metrics
if labels is not None:
answer_mask = labels != -100
metrics["answer_ratio"] = answer_mask.float().mean().item()
metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item()
if self.config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()

View File

@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning("Empty text requested for tokenization.")
LOG.warning_once("Empty text requested for tokenization.")
return empty
result = self.tokenizer(