diff --git a/src/axolotl/monkeypatch/loss/__init__.py b/src/axolotl/monkeypatch/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/loss/chunked.py b/src/axolotl/monkeypatch/loss/chunked.py new file mode 100644 index 000000000..d1f9b32b9 --- /dev/null +++ b/src/axolotl/monkeypatch/loss/chunked.py @@ -0,0 +1,129 @@ +""" +chunked ce loss +""" + +from typing import List, Optional + +import torch +import torch.nn.functional as F + + +# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss +class CEWithChunkedOutputLoss(torch.nn.Module): + """ + Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time. + + For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390 + """ + + def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): + super().__init__() + self.num_output_chunks = num_output_chunks + self.ignore_index = ignore_index + + def compute_cross_entropy( + self, + logits: torch.Tensor, + labels: torch.Tensor, + normalize: bool = True, # pylint: disable=unused-argument + ) -> torch.Tensor: + """ + Upcast logits to fp32 and compute cross entropy loss. + """ + return F.cross_entropy( + logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" + ) + + def forward( + self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum" + ) -> torch.Tensor: + """ + Args: + logits (List[torch.Tensor]): List of chunked logits of length + ``self.num_output_chunks``, where each chunk has shape + ``(batch_size, num_tokens / num_output_chunks, vocab_size)``. + labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``. + reduction (str): The reduction to apply to the output. + + Returns: + torch.Tensor: Cross entropy loss of shape (1,). + """ + + total_elements = (labels != self.ignore_index).sum() + + # chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)] + labels = [ + target_chunk.reshape(-1) + for target_chunk in labels.chunk(self.num_output_chunks, dim=1) + ] + # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)] + logits = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits + ] + + # compute one chunk at a time + total_loss = 0.0 + for logits_chunk, labels_chunk in zip(logits, labels): + total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk) + + if reduction == "sum": + return total_loss + return total_loss / total_elements + + +def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): + loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index) + loss_fn_ce.compute_cross_entropy = torch.compile( + loss_fn_ce.compute_cross_entropy, backend="inductor" + ) + return loss_fn_ce + + +def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): + import transformers.loss.loss_utils + + loss_fn_ce = build_chunked_ce_loss_fn(num_output_chunks, ignore_index) + + def chunked_fix_cross_entropy( + source, + target, + num_items_in_batch: int = None, + ignore_index: int = -100, + **kwargs, + ): # pylint: disable=unused-argument + reduction = "sum" if num_items_in_batch is not None else "mean" + logit_chunks = [ # pylint: disable=unnecessary-comprehension + chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1) + ] + loss = loss_fn_ce(logit_chunks, target, reduction=reduction) + if reduction == "sum": + loss = loss / num_items_in_batch + return loss + + def for_causal_lm_chunked_loss( + logits, + labels, + vocab_size: int, # pylint: disable=unused-argument + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + shift_labels: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # skip the upcast to float since we handle that in the chunking loss + if shift_labels is None: + # Shift so that tokens < n predict n + labels = F.pad(labels, (0, 1), value=ignore_index) + shift_labels = labels[..., 1:].contiguous() + + # Skip Flattening the tokens + # Enable model parallelism + shift_labels = shift_labels.to(logits.device) + loss = chunked_fix_cross_entropy( + logits, shift_labels, num_items_in_batch, ignore_index, **kwargs + ) + return loss + + transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss + transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = ( + for_causal_lm_chunked_loss + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index eaaa2a450..95dae7e20 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -556,6 +556,14 @@ class ModelLoader: self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name def apply_patches(self) -> None: + if self.cfg.chunked_cross_entropy: + from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn + + if self.cfg.chunked_cross_entropy_num_chunks: + patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks) + else: + patch_chunked_ce_loss_fn() + if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 02308695c..40631515e 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -242,6 +242,9 @@ class AxolotlInputConfig( unsloth_rms_norm: bool | None = None unsloth_rope: bool | None = None + chunked_cross_entropy: bool | None = None + chunked_cross_entropy_num_chunks: int | None = None + lora_mlp_kernel: bool | None = None lora_qkv_kernel: bool | None = None lora_o_kernel: bool | None = None