From e19be0c2d945260ff61186e0dc7e845773fa7c7b Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 15 Aug 2025 02:21:25 +0000
Subject: [PATCH] add back in reinit_weights (clobbered?); masking / pretrain
 fixes

---
 .../llama-3/diffusion-3.2-1b-pretrain.yaml    |   6 +-
 src/axolotl/core/trainers/base.py             |  48 +++++++-
 src/axolotl/integrations/diffusion/trainer.py |  82 +++++++------
 src/axolotl/loaders/model.py                  | 108 ++++++++----------
 src/axolotl/utils/schemas/config.py           |   6 +
 5 files changed, 139 insertions(+), 111 deletions(-)

diff --git a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
index 95d820cca..965e248eb 100644
--- a/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
+++ b/examples/llama-3/diffusion-3.2-1b-pretrain.yaml
@@ -34,7 +34,7 @@ lr_scheduler: cosine
 learning_rate: 3e-4
 
 bf16: auto
-tf32: false
+tf32: true
 
 gradient_checkpointing: true
 resume_from_checkpoint:
@@ -51,8 +51,8 @@ eval_steps: 1000
 special_tokens:
   pad_token: "<|end_of_text|>"
 
-wandb_project: diffusion-plugin
-wandb_entity: axolotl-ai
+wandb_project:
+wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 0f9f6e4c4..77a9cc83a 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -82,7 +82,9 @@ class AxolotlTrainer(
         super().__init__(*_args, **kwargs)
 
         self.train_data_collator = self.data_collator
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        self._stored_metrics = defaultdict(
+            lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
+        )
 
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -573,9 +575,26 @@ class AxolotlTrainer(
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
-        # Add averaged stored metrics to logs
-        for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
+
+        # Add reduced stored metrics to logs
+        for key, metric_data in self._stored_metrics[train_eval].items():
+            values = torch.tensor(metric_data["values"])
+            reduction_type = metric_data["reduction"]
+
+            if reduction_type == "mean":
+                logs[key] = values.mean().item()
+            elif reduction_type == "min":
+                logs[key] = values.min().item()
+            elif reduction_type == "max":
+                logs[key] = values.max().item()
+            elif reduction_type == "sum":
+                logs[key] = values.sum().item()
+            else:
+                raise NotImplementedError(
+                    "Metric reduction must be one of [mean, min, max, sum]"
+                )
+
+            logs[key] = round(logs[key], 4)
 
         if is_main_process():
             # Add memory usage
@@ -592,10 +611,28 @@ class AxolotlTrainer(
         return super().log(logs, start_time)
 
     def store_metrics(
-        self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
+        self,
+        metrics: dict[str, float] | dict[str, tuple[int | float, str]],
+        train_eval: Literal["train", "eval"] = "train",
+        reduction: Literal["mean", "min", "max", "sum"] = "mean",
    ) -> None:
+        """
+        Store metrics with specified reduction type.
+
+        Args:
+            metrics: Dictionary of metric names to values, or metric names to (value,
+                reduction_type) tuples.
+            train_eval: Whether this is for training or evaluation.
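+            reduction: Default reduction for metrics passed as bare values.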
+        """
         for key, value in metrics.items():
-            self._stored_metrics[train_eval][key].append(value)
+            if isinstance(value, tuple):
+                metric_value, metric_reduction = value
+            else:
+                metric_value, metric_reduction = value, reduction
+
+            self._stored_metrics[train_eval][key]["values"].append(metric_value)
+            self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
 
     def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py
index 160b5692b..ed81fd029 100644
--- a/src/axolotl/integrations/diffusion/trainer.py
+++ b/src/axolotl/integrations/diffusion/trainer.py
@@ -1,5 +1,7 @@
 """Custom trainer for diffusion LM training."""
 
+from typing import Any, Literal
+
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -16,14 +18,35 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self.config = None
+        self._config = None
         self._special_token_ids = None
 
     def set_config(self, config: DictDefault):
         """Set config for diffusion training."""
-        self.config = config
+        self._config = config
         self._cache_special_token_ids()
 
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor],
+        return_outputs: bool = False,
+        num_items_in_batch: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
+        """Override compute_loss to use diffusion loss."""
+        input_ids = inputs.get("input_ids")
+        attention_mask = inputs.get("attention_mask")
+
+        if input_ids is None:
+            raise ValueError("input_ids is required for diffusion training")
+
+        loss, outputs = self._compute_diffusion_loss(model, input_ids, attention_mask)
+
+        if return_outputs:
+            return loss, outputs
+
+        return loss
+
     def _cache_special_token_ids(self):
         """Cache special token IDs to avoid repeated tokenizer access."""
         if self.processing_class is None:
@@ -42,7 +65,7 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
 
         self._special_token_ids = special_tokens
 
-    def forward_process(
+    def _forward_process(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
@@ -90,14 +113,14 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
             masked_indices = masked_indices & attention_mask.bool()
 
         # Get mask token ID from config
-        mask_token_id = self.config.mask_token_id
+        mask_token_id = self._config.mask_token_id
 
         # Create masked input using configured mask token
         noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
 
         return noisy_batch, masked_indices, p_mask
 
-    def create_bidirectional_attention_mask(
+    def _create_bidirectional_attention_mask(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
     ) -> torch.Tensor:
         """
@@ -115,7 +138,7 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
         batch_size, seq_len = input_ids.shape
         device = input_ids.device
 
-        if attention_mask is None or not self.config.sample_packing:
+        if attention_mask is None or not self._config.sample_packing:
             # Simple case: no attention mask, allow all-to-all attention
             return torch.ones(
                 batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
@@ -133,12 +156,12 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
 
         return bidirectional_mask
 
-    def compute_diffusion_loss(
+    def _compute_diffusion_loss(
         self,
        model: nn.Module,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, dict[str, float]]:
+    ) -> tuple[torch.Tensor, torch.Tensor | Any]:
         """
         Compute diffusion loss.
 
@@ -152,12 +175,12 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
-            metrics: Dictionary of metrics.
+            outputs: Model outputs.
         """
         # Apply forward process
-        noisy_batch, masked_indices, p_mask = self.forward_process(
-            input_ids, attention_mask, self.config.eps
+        noisy_batch, masked_indices, p_mask = self._forward_process(
+            input_ids, attention_mask, self._config.eps
         )
 
         # Create bidirectional attention mask
-        bidirectional_mask = self.create_bidirectional_attention_mask(
+        bidirectional_mask = self._create_bidirectional_attention_mask(
             input_ids, attention_mask
         )
@@ -187,7 +210,7 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
         )
 
         # Apply importance weighting if enabled
-        if self.config.importance_weighting:
+        if self._config.importance_weighting:
             masked_p_mask = masked_p_mask.float()
             weighted_loss = token_loss / masked_p_mask
         else:
@@ -211,40 +234,15 @@ class DiffusionTrainer(AxolotlTrainer):  # pylint: disable=too-many-ancestors
             "loss": loss.item(),
             "accuracy": accuracy.item(),
             "mask_ratio": masked_indices.float().mean().item(),
-            "num_masked_tokens": masked_indices.sum().item(),
+            "num_masked_tokens": (masked_indices.sum().item(), "sum"),
             "avg_p_mask": p_mask[masked_indices].mean().item(),
             "ce_loss": ce_loss.item(),
         }
-        if self.config.importance_weighting:
+        if self._config.importance_weighting:
             metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
 
-        return loss, metrics
+        train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
+        self.store_metrics(metrics, train_eval=train_eval)
 
-    def compute_loss(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor],
-        return_outputs: bool = False,
-        num_items_in_batch: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
-        """Override compute_loss to use diffusion loss."""
-        input_ids = inputs.get("input_ids")
-        attention_mask = inputs.get("attention_mask")
-
-        if input_ids is None:
-            raise ValueError("input_ids is required for diffusion training")
-
-        loss, metrics = self.compute_diffusion_loss(model, input_ids, attention_mask)
-
-        # # Log metrics
-        # if self.state.is_local_process_zero:
-        #     for key, value in metrics.items():
-        #         self.log({f"train/diffusion_{key}": value})
-
-        if return_outputs:
-            # TODO: compute outputs (?)
-            outputs = [loss]
-            return (loss, outputs)
-
-        return loss
+        return loss, outputs
 
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 6bf1f149b..1cb33d13c 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -667,6 +667,23 @@ class ModelLoader:
 
         return hf_ds_cfg
 
+    def _load_model_from_config(self) -> PreTrainedModel:
+        """Load model with random initialization using from_config."""
+        if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
+            return self.auto_model_loader.from_config(config=self.model_config)
+        return self.auto_model_loader(config=self.model_config)
+
+    def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
+        """Load model from pretrained weights."""
+        loader = model_loader_class or self.auto_model_loader
+        kwargs = {
+            "config": self.model_config,
+            "trust_remote_code": self.cfg.trust_remote_code or False,
+            # user-provided model kwargs take precedence over the defaults above
+            **self.model_kwargs,
+        }
+        return loader.from_pretrained(self.base_model, **kwargs)
+
     def _build_model(self) -> bool:
         """Load model, with load strategy depending on config."""
         skip_move_to_device = False
@@ -681,7 +698,8 @@ class ModelLoader:
         if self.is_fsdp_enabled:
             if self.cfg.fsdp_config.cpu_ram_efficient_loading:
                 skip_move_to_device = True
-        # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
+        # Don't delete device_map for QLoRA + FSDP - it was set correctly in
+        # _set_device_map
         if (
             "device_map" in self.model_kwargs
             and not self.is_qlora_and_fsdp_enabled
@@ -710,6 +728,11 @@ class ModelLoader:
                 or self.cfg.qlora_sharded_model_loading
             )
         ):
+            if self.cfg.reinit_weights:
+                LOG.warning(
+                    "reinit_weights is not supported with sharded quantized loading. "
+                    "Loading from pretrained weights instead."
+                )
             quant_storage = self.cfg.torch_dtype
             quantization_config = getattr(
                 self.model_config, "quantization_config", None
@@ -725,33 +748,12 @@ class ModelLoader:
                 quantization_config=quantization_config,
             )
             skip_move_to_device = True
-        elif (
-            self.model_config.model_type in ["llama", "llama4"]
-            and not self.cfg.trust_remote_code
-            and not self.cfg.gptq
-        ):
-            # Please don't remove underscore binding without reading the fn docstring.
-            _ = self._configure_zero3_memory_efficient_loading()
-
-            # Load model with random initialization if specified
-            if self.cfg.random_init_weights:
-                # AutoModel classes support the from_config method
-                if self.auto_model_loader in [
-                    AutoModelForCausalLM,
-                    AutoModelForVision2Seq,
-                ]:
-                    self.model = self.auto_model_loader.from_config(
-                        config=self.model_config,
-                    )
-                else:
-                    self.model = self.auto_model_loader(config=self.model_config)
-            else:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    **self.model_kwargs,
-                )
         elif self.model_type == "MambaLMHeadModel":
+            if self.cfg.reinit_weights:
+                LOG.warning(
+                    "reinit_weights is not supported with MambaLMHeadModel. "
+                    "Loading from pretrained weights instead."
+                )
             # FIXME this is janky at best and hacked together to make it work
             MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
 
@@ -764,41 +766,27 @@ class ModelLoader:
                 self.base_model,
                 **self.model_kwargs,
             )
-        elif (
-            self.model_type
-            and self.model_type != "AutoModelForCausalLM"
-            and not self.cfg.trust_remote_code
-        ):
-            if self.cfg.gptq:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-            else:
-                self.model = getattr(transformers, self.model_type).from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-        elif self.cfg.gptq:
-            self.model = self.auto_model_loader.from_pretrained(
-                self.base_model,
-                config=self.model_config,
-                trust_remote_code=self.cfg.trust_remote_code or False,
-                **self.model_kwargs,
-            )
         else:
-            # Please don't remove underscore binding without reading the fn docstring.
+            # Please don't remove underscore binding without reading the fn docstring
             _ = self._configure_zero3_memory_efficient_loading()
-            self.model = self.auto_model_loader.from_pretrained(
-                self.base_model,
-                config=self.model_config,
-                trust_remote_code=self.cfg.trust_remote_code or False,
-                **self.model_kwargs,
-            )
+
+            if (
+                self.model_type
+                and self.model_type != "AutoModelForCausalLM"
+                and not self.cfg.trust_remote_code
+                and not self.cfg.gptq
+            ):
+                # Use model type from transformers
+                model_loader_class = getattr(transformers, self.model_type)
+            else:
+                # Use auto model loader (handles gptq and default cases)
+                model_loader_class = self.auto_model_loader
+
+            if self.cfg.reinit_weights:
+                self.model = self._load_model_from_config()
+            else:
+                self.model = self._load_model_from_pretrained(model_loader_class)
+
         if is_deepspeed_zero3_enabled():
             skip_move_to_device = True
 
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 21e99c048..9cd98a4b2 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -109,6 +109,12 @@ class AxolotlInputConfig(
             "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
         },
     )
+    reinit_weights: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Reinitialize model weights randomly instead of loading pretrained weights"
+        },
+    )
     trainer_cls: str | None = Field(
         default=None,