chunked cross entropy loss (#2625)

* chunked cross entropy loss

* refactor so we can add test

* use relative import

* update schema description
This commit is contained in:
Wing Lian
2025-06-23 23:08:46 -04:00
committed by GitHub
parent 1d8f500709
commit 12c826816d
6 changed files with 198 additions and 1 deletions

View File

@@ -71,7 +71,7 @@ class CutCrossEntropyPlugin(BasePlugin):
if cfg.cut_cross_entropy:
self._check_requirements()
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
from .monkeypatch.patch import (
cce_patch,
)

View File

@@ -50,6 +50,7 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_flex_attention_patches()
@@ -78,6 +79,15 @@ class PatchManager:
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
def _apply_chunked_cross_entropy_patch(self):
if self.cfg.chunked_cross_entropy:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn()
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":

View File

View File

@@ -0,0 +1,134 @@
"""
chunked ce loss
"""
from typing import List, Optional
import torch
import torch.nn.functional as F
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
class CEWithChunkedOutputLoss(torch.nn.Module):
"""
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
"""
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
super().__init__()
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
def compute_cross_entropy(
self,
logits: torch.Tensor,
labels: torch.Tensor,
normalize: bool = True, # pylint: disable=unused-argument
) -> torch.Tensor:
"""
Upcast logits to fp32 and compute cross entropy loss.
"""
return F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
)
def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor:
"""
Args:
logits (List[torch.Tensor]): List of chunked logits of length
``self.num_output_chunks``, where each chunk has shape
``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
reduction (str): The reduction to apply to the output.
Returns:
torch.Tensor: Cross entropy loss of shape (1,).
"""
total_elements = (labels != self.ignore_index).sum()
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
labels = [
target_chunk.reshape(-1)
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
]
# reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
logits = [
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
]
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
if reduction == "sum":
return total_loss
return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor"
)
return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
def chunked_fix_cross_entropy(
source,
target,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
logit_chunks = [ # pylint: disable=unnecessary-comprehension
chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
]
loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
def for_causal_lm_chunked_loss(
logits,
labels,
vocab_size: int = None, # pylint: disable=unused-argument
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# skip the upcast to float since we handle that in the chunking loss
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
# Skip Flattening the tokens
# Enable model parallelism
shift_labels = shift_labels.to(logits.device)
loss = chunked_fix_cross_entropy(
logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
)
return loss
return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss
)

View File

@@ -523,6 +523,19 @@ class AxolotlInputConfig(
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use chunked cross entropy loss for memory efficiency"
},
)
chunked_cross_entropy_num_chunks: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of chunks to use for chunked cross entropy loss"
},
)
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = Field(