add back in reinit_weights (clobbered?); masking / pretrain fixes
@@ -34,7 +34,7 @@ lr_scheduler: cosine
 learning_rate: 3e-4

 bf16: auto
-tf32: false
+tf32: true

 gradient_checkpointing: true
 resume_from_checkpoint:
@@ -51,8 +51,8 @@ eval_steps: 1000
 special_tokens:
   pad_token: "<|end_of_text|>"

-wandb_project: diffusion-plugin
-wandb_entity: axolotl-ai
+wandb_project:
+wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:

@@ -82,7 +82,9 @@ class AxolotlTrainer(
         super().__init__(*_args, **kwargs)

         self.train_data_collator = self.data_collator
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        self._stored_metrics = defaultdict(
+            lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
+        )
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

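The reworked _stored_metrics layout keeps, per split and per metric name, the raw values plus the reduction to apply at log time. A minimal standalone sketch of how that nested defaultdict behaves (plain Python, outside the trainer):

from collections import defaultdict

# Per split ("train"/"eval"), per metric name: raw values plus the reduction
# that will be applied when the metric is logged.
stored_metrics = defaultdict(
    lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)

stored_metrics["train"]["accuracy"]["values"].append(0.52)
stored_metrics["train"]["num_masked_tokens"]["values"].append(1024)
stored_metrics["train"]["num_masked_tokens"]["reduction"] = "sum"

print(stored_metrics["train"]["accuracy"])
# {'values': [0.52], 'reduction': 'mean'}

Missing keys materialize with an empty values list and a "mean" default, so callers never have to pre-register a metric.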
@@ -573,9 +575,26 @@ class AxolotlTrainer(
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
-        # Add averaged stored metrics to logs
-        for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
+
+        # Add reduced stored metrics to logs
+        for key, metric_data in self._stored_metrics[train_eval].items():
+            values = torch.tensor(metric_data["values"])
+            reduction_type = metric_data["reduction"]
+
+            if reduction_type == "mean":
+                logs[key] = values.mean().item()
+            elif reduction_type == "min":
+                logs[key] = values.min().item()
+            elif reduction_type == "max":
+                logs[key] = values.max().item()
+            elif reduction_type == "sum":
+                logs[key] = values.sum().item()
+            else:
+                raise NotImplementedError(
+                    "Metric reduction must be one of [mean, min, max, sum]"
+                )
+
+            logs[key] = round(logs[key], 4)

         if is_main_process():
             # Add memory usage
@@ -592,10 +611,27 @@ class AxolotlTrainer(
         return super().log(logs, start_time)

     def store_metrics(
-        self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
+        self,
+        metrics: dict[str, float] | dict[str, tuple[int | float, str]],
+        train_eval: Literal["train", "eval"] = "train",
+        reduction: Literal["mean", "min", "max", "sum"] = "mean",
     ) -> None:
+        """
+        Store metrics with specified reduction type.
+
+        Args:
+            metrics: Dictionary of metric names to values, or metric names to (value,
+                reduction_type) tuples.
+            train_eval: Whether this is for training or evaluation.
+        """
         for key, value in metrics.items():
-            self._stored_metrics[train_eval][key].append(value)
+            if isinstance(value, tuple):
+                metric_value, metric_reduction = value
+            else:
+                metric_value, metric_reduction = value, reduction
+
+            self._stored_metrics[train_eval][key]["values"].append(metric_value)
+            self._stored_metrics[train_eval][key]["reduction"] = metric_reduction

     def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
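With the widened signature, callers can pass plain floats (reduced with the default or the reduction argument) or attach a per-metric reduction as a (value, reduction) tuple. A hedged usage sketch, assuming trainer is an AxolotlTrainer instance:

# Per-step metrics accumulate here and are reduced later in log().
trainer.store_metrics(
    {
        "accuracy": 0.52,                    # uses the default "mean" reduction
        "num_masked_tokens": (1024, "sum"),  # per-metric override via tuple
    },
    train_eval="train",
)

# All metrics in one call can also share a non-default reduction:
trainer.store_metrics({"grad_norm_peak": 3.1}, train_eval="train", reduction="max")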
@@ -1,5 +1,7 @@
 """Custom trainer for diffusion LM training."""

+from typing import Any, Literal
+
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -16,14 +18,35 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.config = None
+        self._config = None
         self._special_token_ids = None

     def set_config(self, config: DictDefault):
         """Set config for diffusion training."""
-        self.config = config
+        self._config = config
         self._cache_special_token_ids()

+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor],
+        return_outputs: bool = False,
+        num_items_in_batch: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        """Override compute_loss to use diffusion loss."""
+        input_ids = inputs.get("input_ids")
+        attention_mask = inputs.get("attention_mask")
+
+        if input_ids is None:
+            raise ValueError("input_ids is required for diffusion training")
+
+        loss, outputs = self._compute_diffusion_loss(model, input_ids, attention_mask)
+
+        if return_outputs:
+            return loss, outputs
+
+        return loss
+
     def _cache_special_token_ids(self):
         """Cache special token IDs to avoid repeated tokenizer access."""
         if self.processing_class is None:
@@ -42,7 +65,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors

         self._special_token_ids = special_tokens

-    def forward_process(
+    def _forward_process(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
@@ -90,14 +113,14 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
             masked_indices = masked_indices & attention_mask.bool()

         # Get mask token ID from config
-        mask_token_id = self.config.mask_token_id
+        mask_token_id = self._config.mask_token_id

         # Create masked input using configured mask token
         noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)

         return noisy_batch, masked_indices, p_mask

-    def create_bidirectional_attention_mask(
+    def _create_bidirectional_attention_mask(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
     ) -> torch.Tensor:
         """
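For orientation, the forward (noising) process in this style of diffusion LM training draws a per-sample masking probability, masks tokens at that rate, and replaces them with the mask token; only masked positions later contribute to the loss. A rough standalone sketch, not the trainer's exact implementation (the uniform p_mask schedule and eps floor are assumptions):

import torch

def forward_mask(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    """Mask each token with a per-sample probability p ~ U(eps, 1)."""
    batch_size, seq_len = input_ids.shape
    t = torch.rand(batch_size, device=input_ids.device)   # timestep per sample
    p_mask = (1 - eps) * t + eps                           # masking probability
    p_mask = p_mask[:, None].expand(batch_size, seq_len)
    masked_indices = torch.rand_like(p_mask) < p_mask      # Bernoulli(p_mask) draw
    noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
    return noisy_batch, masked_indices, p_mask

noisy, mask, p = forward_mask(torch.randint(5, 100, (2, 8)), mask_token_id=0)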
@@ -115,7 +138,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
         batch_size, seq_len = input_ids.shape
         device = input_ids.device

-        if attention_mask is None or not self.config.sample_packing:
+        if attention_mask is None or not self._config.sample_packing:
             # Simple case: no attention mask, allow all-to-all attention
             return torch.ones(
                 batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
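Under sample packing, the all-to-all branch above is skipped and the mask has to stay block-diagonal so tokens only attend within their own packed sequence. A hedged sketch of that construction, assuming the attention mask carries per-sequence ids with 0 marking padding (the id convention is an assumption, not necessarily what the collator emits):

import torch

def block_diagonal_bidirectional_mask(seq_ids: torch.Tensor) -> torch.Tensor:
    """seq_ids: (batch, seq_len) ints; equal non-zero ids = same packed sample."""
    same_sample = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(2)        # (b, s, s)
    not_padding = (seq_ids != 0).unsqueeze(1) & (seq_ids != 0).unsqueeze(2)
    return (same_sample & not_padding).unsqueeze(1)                   # (b, 1, s, s)

# Two packed samples (ids 1 and 2) plus padding (0) in one row:
ids = torch.tensor([[1, 1, 1, 2, 2, 0]])
mask = block_diagonal_bidirectional_mask(ids)
print(mask.shape)  # torch.Size([1, 1, 6, 6])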
@@ -133,12 +156,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors

         return bidirectional_mask

-    def compute_diffusion_loss(
+    def _compute_diffusion_loss(
         self,
         model: nn.Module,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, dict[str, float]]:
+    ) -> tuple[torch.Tensor, torch.Tensor | Any]:
         """
         Compute diffusion loss.

@@ -152,12 +175,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
             metrics: Dictionary of metrics.
         """
         # Apply forward process
-        noisy_batch, masked_indices, p_mask = self.forward_process(
-            input_ids, attention_mask, self.config.eps
+        noisy_batch, masked_indices, p_mask = self._forward_process(
+            input_ids, attention_mask, self._config.eps
         )

         # Create bidirectional attention mask
-        bidirectional_mask = self.create_bidirectional_attention_mask(
+        bidirectional_mask = self._create_bidirectional_attention_mask(
             input_ids, attention_mask
         )

@@ -187,7 +210,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
         )

         # Apply importance weighting if enabled
-        if self.config.importance_weighting:
+        if self._config.importance_weighting:
             masked_p_mask = masked_p_mask.float()
             weighted_loss = token_loss / masked_p_mask
         else:
@@ -211,40 +234,15 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors
             "loss": loss.item(),
             "accuracy": accuracy.item(),
             "mask_ratio": masked_indices.float().mean().item(),
-            "num_masked_tokens": masked_indices.sum().item(),
+            "num_masked_tokens": (masked_indices.sum().item(), "sum"),
             "avg_p_mask": p_mask[masked_indices].mean().item(),
             "ce_loss": ce_loss.item(),
         }

-        if self.config.importance_weighting:
+        if self._config.importance_weighting:
             metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()

-        return loss, metrics
+        train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
+        self.store_metrics(metrics, train_eval=train_eval)

-    def compute_loss(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor],
-        return_outputs: bool = False,
-        num_items_in_batch: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
-        """Override compute_loss to use diffusion loss."""
-        input_ids = inputs.get("input_ids")
-        attention_mask = inputs.get("attention_mask")
-
-        if input_ids is None:
-            raise ValueError("input_ids is required for diffusion training")
-
-        loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
-
-        # # Log metrics
-        # if self.state.is_local_process_zero:
-        #     for key, value in metrics.items():
-        #         self.log({f"train/diffusion_{key}": value})
-
-        if return_outputs:
-            # TODO: compute outputs (?)
-            outputs = [loss]
-            return (loss, outputs)
-
-        return loss
+        return loss, outputs

@@ -667,6 +667,23 @@ class ModelLoader:

         return hf_ds_cfg

+    def _load_model_from_config(self) -> PreTrainedModel:
+        """Load model with random initialization using from_config."""
+        if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
+            return self.auto_model_loader.from_config(config=self.model_config)
+        return self.auto_model_loader(config=self.model_config)
+
+    def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
+        """Load model from pretrained weights."""
+        loader = model_loader_class or self.auto_model_loader
+        kwargs = {
+            **self.model_kwargs,
+            "config": self.model_config,
+            "trust_remote_code": self.cfg.trust_remote_code or False,
+            **self.model_kwargs,
+        }
+        return loader.from_pretrained(self.base_model, **kwargs)
+
     def _build_model(self) -> bool:
         """Load model, with load strategy depending on config."""
         skip_move_to_device = False
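The two helpers separate the load paths: from_config builds the architecture with freshly initialized weights (the reinit_weights path), while from_pretrained pulls checkpoint weights. A minimal sketch with stock transformers classes (the model name is only illustrative):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # example model

# Random init: architecture from the config, weights freshly initialized.
random_model = AutoModelForCausalLM.from_config(config)

# Pretrained: same architecture, weights loaded from the checkpoint.
pretrained_model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M", config=config
)

In the config-driven flow, setting reinit_weights: true routes _build_model through the first helper; everything else keeps the pretrained path.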
@@ -681,7 +698,8 @@ class ModelLoader:
         if self.is_fsdp_enabled:
             if self.cfg.fsdp_config.cpu_ram_efficient_loading:
                 skip_move_to_device = True
-                # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
+                # Don't delete device_map for QLoRA + FSDP - it was set correctly in
+                # _set_device_map
                 if (
                     "device_map" in self.model_kwargs
                     and not self.is_qlora_and_fsdp_enabled
@@ -710,6 +728,11 @@ class ModelLoader:
                 or self.cfg.qlora_sharded_model_loading
             )
         ):
+            if self.cfg.reinit_weights:
+                LOG.warning(
+                    "reinit_weights is not supported with sharded quantized loading. "
+                    "Loading from pretrained weights instead."
+                )
             quant_storage = self.cfg.torch_dtype
             quantization_config = getattr(
                 self.model_config, "quantization_config", None
@@ -725,33 +748,12 @@ class ModelLoader:
                 quantization_config=quantization_config,
             )
             skip_move_to_device = True
-        elif (
-            self.model_config.model_type in ["llama", "llama4"]
-            and not self.cfg.trust_remote_code
-            and not self.cfg.gptq
-        ):
-            # Please don't remove underscore binding without reading the fn docstring.
-            _ = self._configure_zero3_memory_efficient_loading()
-
-            # Load model with random initialization if specified
-            if self.cfg.random_init_weights:
-                # AutoModel classes support the from_config method
-                if self.auto_model_loader in [
-                    AutoModelForCausalLM,
-                    AutoModelForVision2Seq,
-                ]:
-                    self.model = self.auto_model_loader.from_config(
-                        config=self.model_config,
-                    )
-                else:
-                    self.model = self.auto_model_loader(config=self.model_config)
-            else:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    **self.model_kwargs,
-                )
         elif self.model_type == "MambaLMHeadModel":
+            if self.cfg.reinit_weights:
+                LOG.warning(
+                    "reinit_weights is not supported with MambaLMHeadModel. "
+                    "Loading from pretrained weights instead."
+                )
             # FIXME this is janky at best and hacked together to make it work
             MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name

@@ -764,41 +766,27 @@ class ModelLoader:
                 self.base_model,
                 **self.model_kwargs,
             )
-        elif (
-            self.model_type
-            and self.model_type != "AutoModelForCausalLM"
-            and not self.cfg.trust_remote_code
-        ):
-            if self.cfg.gptq:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-            else:
-                self.model = getattr(transformers, self.model_type).from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-        elif self.cfg.gptq:
-            self.model = self.auto_model_loader.from_pretrained(
-                self.base_model,
-                config=self.model_config,
-                trust_remote_code=self.cfg.trust_remote_code or False,
-                **self.model_kwargs,
-            )
         else:
-            # Please don't remove underscore binding without reading the fn docstring.
+            # Please don't remove underscore binding without reading the fn docstring
             _ = self._configure_zero3_memory_efficient_loading()
-            self.model = self.auto_model_loader.from_pretrained(
-                self.base_model,
-                config=self.model_config,
-                trust_remote_code=self.cfg.trust_remote_code or False,
-                **self.model_kwargs,
-            )
+
+            if (
+                self.model_type
+                and self.model_type != "AutoModelForCausalLM"
+                and not self.cfg.trust_remote_code
+                and not self.cfg.gptq
+            ):
+                # Use model type from transformers
+                model_loader_class = getattr(transformers, self.model_type)
+            else:
+                # Use auto model loader (handles gptq and default cases)
+                model_loader_class = self.auto_model_loader
+
+            if self.cfg.reinit_weights:
+                self.model = self._load_model_from_config()
+            else:
+                self.model = self._load_model_from_pretrained(model_loader_class)

         if is_deepspeed_zero3_enabled():
             skip_move_to_device = True

@@ -109,6 +109,12 @@ class AxolotlInputConfig(
             "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
         },
     )
+    reinit_weights: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Reinitialize model weights randomly instead of loading pretrained weights"
+        },
+    )

     trainer_cls: str | None = Field(
         default=None,