From cf8c93e2ee0920d2e223788db1c480fc8c6f13fb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 19 Aug 2025 09:36:57 -0400 Subject: [PATCH] wip --- src/axolotl/core/trainers/base.py | 12 ++ src/axolotl/integrations/diffusion/trainer.py | 117 +++++++++++++----- src/axolotl/integrations/diffusion/utils.py | 50 ++++++++ src/axolotl/integrations/spectrum/__init__.py | 2 +- src/axolotl/utils/environment.py | 2 +- 5 files changed, 151 insertions(+), 32 deletions(-) create mode 100644 src/axolotl/integrations/diffusion/utils.py diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index c433b2a39..af7933793 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -274,6 +274,18 @@ class AxolotlTrainer( num_workers=self.args.dataloader_num_workers, rank=self.args.process_index, ) + if (self.args.accelerator_config is not None + and self.args.accelerator_config.split_batches + and self.args.accelerator_config.dispatch_batches + ): + if self.args.sample_packing and self.args.pretraining: + if not self.args.eval_sample_packing and not is_training: + dataloader_params["batch_size"] *= self.accelerator.num_processes + else: + dataloader_params["batch_size"] = self.accelerator.num_processes + elif not self.args.sample_packing and self.args.pretraining: + dataloader_params["batch_size"] *= self.accelerator.num_processes + if self.args.sample_packing and ( (is_training and not self.args.pretraining) or (not is_training and self.args.eval_sample_packing is not False) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py index dc62035d5..bf0916370 100644 --- a/src/axolotl/integrations/diffusion/trainer.py +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -5,8 +5,10 @@ from typing import Any, Literal import torch import torch.nn.functional as F from torch import nn +from transformers.masking_utils import find_packed_sequence_indices from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.integrations.diffusion.utils import create_bidirectional_block_mask from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -43,12 +45,13 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask") labels = inputs.get("labels") + position_ids = inputs.get("position_ids") if input_ids is None: raise ValueError("input_ids is required for diffusion training") loss, outputs = self._compute_diffusion_loss( - model, input_ids, attention_mask, labels + model, input_ids, attention_mask, labels, position_ids ) if return_outputs: @@ -80,6 +83,8 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, eps: float = 1e-3, + min_p: float = 0.0, + max_p: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward noising process. A timestep is sampled along the process, and tokens are @@ -103,7 +108,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors t = torch.rand(batch_size, device=device) # Calculate masking probability with epsilon - p_mask = (1 - eps) * t + eps # [batch_size] + p_mask = min_p + (max_p - min_p) * (1 - eps) * t + eps # [batch_size] p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] # Don't mask padding tokens if attention_mask is provided @@ -136,7 +141,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors @torch.compile def _create_bidirectional_attention_mask( - self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None + self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None ) -> torch.Tensor: """ Create bidirectional attention mask to override default causal masking. Handles @@ -146,6 +151,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors Args: input_ids: Input token ids [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len] + position_ids: Position ids [batch_size, seq_len] Returns: bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]. @@ -158,17 +164,28 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device ) - # Create attention mask by comparing sample IDs element-wise - mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] - mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] + if position_ids is None: + # Create attention mask by comparing sample IDs element-wise + mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] + mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] - # Tokens can attend to each other if they have the same non-zero sample ID - bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) + # Tokens can attend to each other if they have the same non-zero sample ID + bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) - # Add head dimension: [batch_size, 1, seq_len, seq_len] - bidirectional_mask = bidirectional_mask.unsqueeze(1) + # Add head dimension: [batch_size, 1, seq_len, seq_len] + bidirectional_mask = bidirectional_mask.unsqueeze(1) - return bidirectional_mask + return bidirectional_mask + + if self._config.flex_attention: + block_mask = create_bidirectional_block_mask( + input_ids, attention_mask, position_ids + ) + else: + packed_seq_mask = find_packed_sequence_indices(position_ids) + block_mask = packed_seq_mask.unsqueeze(2) == packed_seq_mask.unsqueeze(1) + + return block_mask def _compute_diffusion_loss( self, @@ -176,6 +193,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | Any]: """ Compute diffusion loss. @@ -185,6 +203,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors input_ids: Ground truth token ids [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len]. labels: Labels for SFT training [batch_size, seq_len]. + position_ids: Position ids [batch_size, seq_len]. Returns: loss: Cross-entropy loss. @@ -192,12 +211,12 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors """ # Apply forward process noisy_batch, masked_indices, p_mask = self._forward_process( - input_ids, attention_mask, labels, self.config.eps + input_ids, attention_mask, labels, self._config.eps, self._config.min_mask_ratio, self._config.max_mask_ratio ) - # Create bidirectional attention mask + # Create bidirectional attention mask (optional: use causal if you want strict AR behavior) bidirectional_mask = self._create_bidirectional_attention_mask( - input_ids, attention_mask + input_ids, attention_mask, position_ids ) # Forward pass @@ -205,15 +224,31 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors input_ids=noisy_batch, attention_mask=bidirectional_mask, ) - logits = outputs.logits + logits = outputs.logits # [B, L, V] - if masked_indices.sum() > 0: - valid_indices = torch.where(masked_indices) + # ----- AR label shift toggle ----- + use_ar_shift = False + if use_ar_shift: + # Predict token at t from logits at t-1: drop last logit step, drop first target step + logits_eff = logits[:, :-1, :] + input_ids_eff = input_ids[:, 1:] + masked_indices_eff = masked_indices[:, 1:] + p_mask_eff = p_mask[:, 1:] + labels_eff = labels[:, 1:] if labels is not None else None + else: + logits_eff = logits + input_ids_eff = input_ids + masked_indices_eff = masked_indices + p_mask_eff = p_mask + labels_eff = labels + + if masked_indices_eff.sum() > 0: + valid_indices = torch.where(masked_indices_eff) batch_indices, seq_indices = valid_indices - masked_logits = logits[batch_indices, seq_indices] - masked_targets = input_ids[batch_indices, seq_indices] - masked_p_mask = p_mask[batch_indices, seq_indices] + masked_logits = logits_eff[batch_indices, seq_indices] + masked_targets = input_ids_eff[batch_indices, seq_indices] + masked_p_mask = p_mask_eff[batch_indices, seq_indices] # Compute cross-entropy loss without reduction token_loss = F.cross_entropy( @@ -221,15 +256,15 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors ) if self.config.importance_weighting: - masked_p_mask = masked_p_mask.float() + masked_p_mask = masked_p_mask.float().clamp_min(1e-6) weighted_loss = token_loss / masked_p_mask else: weighted_loss = token_loss # Final loss: sum weighted losses, normalize - if labels is not None: + if labels_eff is not None: # For SFT data: normalize by answer length per sample - answer_mask = labels != -100 + answer_mask = labels_eff != -100 answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] # Get batch indices for masked tokens @@ -241,7 +276,7 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors ) for i in range(input_ids.shape[0]): sample_mask = masked_batch_indices == i - if sample_mask.sum() > 0: + if sample_mask.any(): sample_loss = weighted_loss[sample_mask].sum() loss_per_sample[i] = sample_loss / answer_lengths[i] @@ -262,14 +297,36 @@ class DiffusionTrainer(AxolotlTrainer): # pylint: disable=too-many-ancestors ce_loss = torch.tensor(0.0, device=input_ids.device) masked_p_mask = torch.tensor(1.0, device=input_ids.device) + # Keep eff tensors around for metrics + masked_indices_eff = masked_indices + p_mask_eff = p_mask + labels_eff = labels + + # Metrics (aligned to the effective tensors) + if masked_indices_eff.any(): + avg_p = p_mask_eff[masked_indices_eff].float().mean().item() + num_masked = int(masked_indices_eff.sum().item()) + mask_ratio = masked_indices_eff.float().mean().item() + else: + avg_p = 0.0 + num_masked = 0 + mask_ratio = 0.0 + metrics = { - "loss": loss.item(), - "accuracy": accuracy.item(), - "mask_ratio": masked_indices.float().mean().item(), - "num_masked_tokens": (masked_indices.sum().item(), "sum"), - "avg_p_mask": p_mask[masked_indices].mean().item(), - "ce_loss": ce_loss.item(), + "loss": float(loss.detach()), + "accuracy": float(accuracy.detach()), + "mask_ratio": mask_ratio, + "num_masked_tokens": (num_masked, "sum"), + "avg_p_mask": avg_p, + "ce_loss": float(ce_loss.detach()), } + + # SFT-specific metrics (aligned) + if labels_eff is not None: + answer_mask = labels_eff != -100 + metrics["answer_ratio"] = answer_mask.float().mean().item() + metrics["avg_answer_length"] = answer_mask.sum(dim=1).float().mean().item() + if self.config.importance_weighting: metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py new file mode 100644 index 000000000..7c8cdf584 --- /dev/null +++ b/src/axolotl/integrations/diffusion/utils.py @@ -0,0 +1,50 @@ +import torch +from torch.nn.attention.flex_attention import BlockMask, create_block_mask +from transformers.masking_utils import find_packed_sequence_indices, packed_sequence_mask_function + + +def create_bidirectional_block_mask( + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, +) -> "BlockMask": + """ + Creates a bidirectional block mask for FlexAttention. + + Args: + input_ids: Input token ids [batch_size, seq_len] + attention_mask: Padding mask [batch_size, seq_len] + + Returns: + BlockMask for bidirectional attention with padding + """ + batch_size, seq_len = input_ids.shape + + if position_ids is not None: + packed_seq_mask = find_packed_sequence_indices(position_ids) + mask_fn =packed_sequence_mask_function(packed_seq_mask, batch_size, seq_len) + elif attention_mask is None: + # If no padding mask, all positions can attend to all positions + def mask_fn(b, h, q_idx, kv_idx): + # Always return True for bidirectional attention + return True + else: + # Convert attention_mask to boolean if needed + attention_mask = attention_mask.bool() + + def mask_fn(b, h, q_idx, kv_idx): + # Both query and key positions must be valid (not padding) + return attention_mask[b, q_idx] & attention_mask[b, kv_idx] + + # Create the block mask + block_mask = create_block_mask( + mask_fn, + B=batch_size, + H=None, # Will be set by the attention layer + Q_LEN=seq_len, + KV_LEN=seq_len, + device=input_ids.device, + _compile=True, + ) + + return block_mask diff --git a/src/axolotl/integrations/spectrum/__init__.py b/src/axolotl/integrations/spectrum/__init__.py index 9f66aef97..d78f0003a 100644 --- a/src/axolotl/integrations/spectrum/__init__.py +++ b/src/axolotl/integrations/spectrum/__init__.py @@ -57,7 +57,7 @@ class SpectrumPlugin(BasePlugin): Spectrum Plugin to automatically generate unfrozen parameters based on SNR data. """ - base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/" + base_url = "https://raw.githubusercontent.com/QuixiAI/spectrum/main/model_snr_results/" base_path = "./model_snr_results/" snr_file_template = "snr_results_{model_name_slug}.json" diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 3c83c87cb..7516129f7 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -16,7 +16,7 @@ from packaging.version import Version, parse def check_cuda_p2p_ib_support(): if not accelerate_check_cuda_p2p_ib_support(): return False - unsupported_devices = {"RTX 6000 Ada", "L40S"} + unsupported_devices = {"RTX 6000 Ada", "L40S", "A40"} try: device_names, device_count = get_gpu_info() if 1 < device_count < 8: