add back in reinit_weights (clobbered?); masking / pretrain fixes

This commit is contained in:
Dan Saunders
2025-08-15 02:21:25 +00:00
parent 479a454ae3
commit e19be0c2d9
5 changed files with 139 additions and 111 deletions

View File

@@ -34,7 +34,7 @@ lr_scheduler: cosine
learning_rate: 3e-4
bf16: auto
tf32: false
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
@@ -51,8 +51,8 @@ eval_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project: diffusion-plugin
wandb_entity: axolotl-ai
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

View File

@@ -82,7 +82,9 @@ class AxolotlTrainer(
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -573,9 +575,26 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
# Add reduced stored metrics to logs
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"])
reduction_type = metric_data["reduction"]
if reduction_type == "mean":
logs[key] = values.mean().item()
elif reduction_type == "min":
logs[key] = values.min().item()
elif reduction_type == "max":
logs[key] = values.max().item()
elif reduction_type == "sum":
logs[key] = values.sum().item()
else:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max]"
)
logs[key] = round(logs[key], 4)
if is_main_process():
# Add memory usage
@@ -592,10 +611,27 @@ class AxolotlTrainer(
return super().log(logs, start_time)
def store_metrics(
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
self,
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
"""
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
if isinstance(value, tuple):
metric_value, metric_reduction = value
else:
metric_value, metric_reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(metric_value)
self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey

View File

@@ -1,5 +1,7 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
@@ -16,14 +18,35 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = None
self._config = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.config = config
self._config = config
self._cache_special_token_ids()
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(model, input_ids, attention_mask)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
@@ -42,7 +65,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
self._special_token_ids = special_tokens
def forward_process(
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
@@ -90,14 +113,14 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
masked_indices = masked_indices & attention_mask.bool()
# Get mask token ID from config
mask_token_id = self.config.mask_token_id
mask_token_id = self._config.mask_token_id
# Create masked input using configured mask token
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def create_bidirectional_attention_mask(
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
@@ -115,7 +138,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
if attention_mask is None or not self._config.sample_packing:
# Simple case: no attention mask, allow all-to-all attention
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
@@ -133,12 +156,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
return bidirectional_mask
def compute_diffusion_loss(
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, float]]:
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
@@ -152,12 +175,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
metrics: Dictionary of metrics.
"""
# Apply forward process
noisy_batch, masked_indices, p_mask = self.forward_process(
input_ids, attention_mask, self.config.eps
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, self._config.eps
)
# Create bidirectional attention mask
bidirectional_mask = self.create_bidirectional_attention_mask(
bidirectional_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
@@ -187,7 +210,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
)
# Apply importance weighting if enabled
if self.config.importance_weighting:
if self._config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
@@ -211,40 +234,15 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
"loss": loss.item(),
"accuracy": accuracy.item(),
"mask_ratio": masked_indices.float().mean().item(),
"num_masked_tokens": masked_indices.sum().item(),
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
"avg_p_mask": p_mask[masked_indices].mean().item(),
"ce_loss": ce_loss.item(),
}
if self.config.importance_weighting:
if self._config.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
return loss, metrics
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
# # Log metrics
# if self.state.is_local_process_zero:
# for key, value in metrics.items():
# self.log({f"train/diffusion_{key}": value})
if return_outputs:
# TODO: compute outputs (?)
outputs = [loss]
return (loss, outputs)
return loss
return loss, outputs

View File

@@ -667,6 +667,23 @@ class ModelLoader:
return hf_ds_cfg
def _load_model_from_config(self) -> PreTrainedModel:
"""Load model with random initialization using from_config."""
if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
return self.auto_model_loader.from_config(config=self.model_config)
return self.auto_model_loader(config=self.model_config)
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
"""Load model from pretrained weights."""
loader = model_loader_class or self.auto_model_loader
kwargs = {
**self.model_kwargs,
"config": self.model_config,
"trust_remote_code": self.cfg.trust_remote_code or False,
**self.model_kwargs,
}
return loader.from_pretrained(self.base_model, **kwargs)
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
@@ -681,7 +698,8 @@ class ModelLoader:
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
# Don't delete device_map for QLoRA + FSDP - it was set correctly in
# _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
@@ -710,6 +728,11 @@ class ModelLoader:
or self.cfg.qlora_sharded_model_loading
)
):
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with sharded quantized loading. "
"Loading from pretrained weights instead."
)
quant_storage = self.cfg.torch_dtype
quantization_config = getattr(
self.model_config, "quantization_config", None
@@ -725,33 +748,12 @@ class ModelLoader:
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(config=self.model_config)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
elif self.model_type == "MambaLMHeadModel":
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with MambaLMHeadModel. "
"Loading from pretrained weights instead."
)
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
@@ -764,41 +766,27 @@ class ModelLoader:
self.base_model,
**self.model_kwargs,
)
elif (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
# Please don't remove underscore binding without reading the fn docstring.
# Please don't remove underscore binding without reading the fn docstring
_ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
if (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Use model type from transformers
model_loader_class = getattr(transformers, self.model_type)
else:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
if self.cfg.reinit_weights:
self.model = self._load_model_from_config()
else:
self.model = self._load_model_from_pretrained(model_loader_class)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True

View File

@@ -109,6 +109,12 @@ class AxolotlInputConfig(
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
},
)
reinit_weights: bool | None = Field(
default=None,
json_schema_extra={
"description": "Reinitialize model weights randomly instead of loading pretrained weights"
},
)
trainer_cls: str | None = Field(
default=None,