diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index b7a53c4d5..7e49846c2 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -153,9 +153,12 @@ class PatchManager: from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn if self.cfg.chunked_cross_entropy_num_chunks: - patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks) + patch_chunked_ce_loss_fn( + self.cfg.chunked_cross_entropy_num_chunks, + use_dft=self.cfg.use_dynamic_finetuning, + ) else: - patch_chunked_ce_loss_fn() + patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning) def _apply_fsdp_patches(self): """Apply patches for FSDP configurations.""" diff --git a/src/axolotl/monkeypatch/loss/chunked.py b/src/axolotl/monkeypatch/loss/chunked.py index 26a52f898..f7fc26364 100644 --- a/src/axolotl/monkeypatch/loss/chunked.py +++ b/src/axolotl/monkeypatch/loss/chunked.py @@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module): For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390 """ - def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): + def __init__( + self, + num_output_chunks: int = 8, + ignore_index: int = -100, + use_dft: bool = False, + ): super().__init__() self.num_output_chunks = num_output_chunks self.ignore_index = ignore_index + self.use_dft = use_dft def compute_cross_entropy( self, @@ -30,10 +36,30 @@ class CEWithChunkedOutputLoss(torch.nn.Module): """ Upcast logits to fp32 and compute cross entropy loss. """ - return F.cross_entropy( - logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" + ce_loss = F.cross_entropy( + logits.float(), labels, ignore_index=self.ignore_index, reduction="none" ) + if self.use_dft: + # Compute probabilities and gather the ones corresponding to labels + with torch.no_grad(): # Stop gradient + probs = torch.softmax(logits.float(), dim=-1) + # Create mask for valid tokens (not ignore_index) + valid_mask = labels != self.ignore_index + # Gather probabilities for the correct tokens + label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1) + # Apply mask to only scale valid tokens + label_probs = label_probs * valid_mask + # Avoid multiplication by 0 for ignored tokens + label_probs = torch.where( + valid_mask, label_probs, torch.ones_like(label_probs) + ) + + # Scale the loss by the probability (DFT) + ce_loss = ce_loss * label_probs + + return ce_loss.sum() + def forward( self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum" ) -> torch.Tensor: @@ -71,16 +97,20 @@ class CEWithChunkedOutputLoss(torch.nn.Module): return total_loss / total_elements -def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): - loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index) +def _build_chunked_ce_loss_fn( + num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False +): + loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft) loss_fn_ce.compute_cross_entropy = torch.compile( loss_fn_ce.compute_cross_entropy, backend="inductor" ) return loss_fn_ce -def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100): - loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index) +def get_causal_lm_loss( + num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False +): + loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft) def chunked_fix_cross_entropy( source, @@ -124,10 +154,14 @@ def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100): return for_causal_lm_chunked_loss -def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): +def patch_chunked_ce_loss_fn( + num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False +): import transformers.loss.loss_utils - for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index) + for_causal_lm_chunked_loss = get_causal_lm_loss( + num_output_chunks, ignore_index, use_dft + ) transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = ( for_causal_lm_chunked_loss diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index da21df7aa..4bf252cb7 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -664,6 +664,13 @@ class AxolotlInputConfig( }, ) + use_dynamic_finetuning: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use dynamic fine-tuning for scaled SFT gradients." + }, + ) + chunked_cross_entropy: bool | None = Field( default=None, json_schema_extra={