diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 9b155ca8a..c29bb55d4 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -71,7 +71,7 @@ class CutCrossEntropyPlugin(BasePlugin): if cfg.cut_cross_entropy: self._check_requirements() - from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import ( + from .monkeypatch.patch import ( cce_patch, ) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index ca8fd1258..7c4e75796 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -50,6 +50,7 @@ class PatchManager: def apply_pre_model_load_patches(self): """Apply pre-model load patches based on config.""" self._apply_flash_attention_patches() + self._apply_chunked_cross_entropy_patch() self._apply_fsdp_patches() self._apply_adapter_patches() self._apply_flex_attention_patches() @@ -78,6 +79,15 @@ class PatchManager: patch_xformers_attn_over_fa2() self.cfg.flash_attention = True + def _apply_chunked_cross_entropy_patch(self): + if self.cfg.chunked_cross_entropy: + from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn + + if self.cfg.chunked_cross_entropy_num_chunks: + patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks) + else: + patch_chunked_ce_loss_fn() + def _apply_fsdp_patches(self): """Apply patches for FSDP configurations.""" if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": diff --git a/src/axolotl/monkeypatch/loss/__init__.py b/src/axolotl/monkeypatch/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/loss/chunked.py b/src/axolotl/monkeypatch/loss/chunked.py new file mode 100644 index 000000000..0a9d0de82 --- /dev/null +++ b/src/axolotl/monkeypatch/loss/chunked.py @@ -0,0 +1,134 @@ +""" +chunked ce loss +""" + +from 
"""
Chunked cross-entropy loss.

Replaces transformers' causal-LM loss with a version that upcasts to fp32 and
computes cross entropy one sequence chunk at a time, bounding the peak memory
held by the float32 logits.
"""

from typing import List, Optional

import torch
import torch.nn.functional as F


# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
class CEWithChunkedOutputLoss(torch.nn.Module):
    """
    Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.

    For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
    """

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        # number of sequence chunks the logits arrive in
        self.num_output_chunks = num_output_chunks
        # label value excluded from the loss (and from the "mean" denominator)
        self.ignore_index = ignore_index

    def compute_cross_entropy(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        normalize: bool = True,  # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """
        Upcast logits to fp32 and compute sum-reduced cross entropy for one chunk.

        ``normalize`` is unused; kept for signature compatibility with the
        torchtune implementation this is copied from.
        """
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

    def forward(
        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction: str = "sum"
    ) -> torch.Tensor:
        """
        Args:
            logits (List[torch.Tensor]): List of chunked logits of length
                ``self.num_output_chunks``, where each chunk has shape
                ``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
            labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
            reduction (str): The reduction to apply: ``"sum"`` or ``"mean"``.

        Returns:
            torch.Tensor: Scalar cross entropy loss.
        """
        # Count non-ignored tokens BEFORE ``labels`` is rebound to the chunk
        # list below; only used by the "mean" reduction.
        total_elements = (labels != self.ignore_index).sum()

        # chunk and reshape labels (bsz, num_tokens) -> [(bsz*num_tokens/num_chunks,)]
        labels = [
            target_chunk.reshape(-1)
            for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
        ]
        # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
        logits = [
            logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
        ]

        # compute one chunk at a time so only one chunk of fp32 logits is live
        total_loss = 0.0
        for logits_chunk, labels_chunk in zip(logits, labels):
            total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)

        if reduction == "sum":
            return total_loss
        return total_loss / total_elements


def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    """Build a ``CEWithChunkedOutputLoss`` with its per-chunk CE compiled by inductor."""
    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
    # compile only the per-chunk kernel; the chunk loop stays in eager mode
    loss_fn_ce.compute_cross_entropy = torch.compile(
        loss_fn_ce.compute_cross_entropy, backend="inductor"
    )
    return loss_fn_ce


def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
    """
    Build a drop-in replacement for transformers' ``ForCausalLMLoss`` that
    computes the loss in ``num_output_chunks`` sequence chunks.

    Returns:
        Callable with the ``ForCausalLMLoss`` signature.
    """
    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)

    def chunked_fix_cross_entropy(
        source,
        target,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):  # pylint: disable=unused-argument
        # Chunk the logits along the sequence dim and dispatch to the chunked
        # CE module. NOTE: the ``ignore_index`` argument here is unused — the
        # value baked into ``loss_fn_ce`` at build time is what takes effect.
        reduction = "sum" if num_items_in_batch is not None else "mean"
        logit_chunks = list(source.chunk(loss_fn_ce.num_output_chunks, dim=1))
        loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
        if reduction == "sum":
            # normalize by the caller-supplied token count (HF convention)
            loss = loss / num_items_in_batch
        return loss

    def for_causal_lm_chunked_loss(
        logits,
        labels,
        vocab_size: Optional[int] = None,  # pylint: disable=unused-argument
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        shift_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Chunked replacement for ``transformers.loss.loss_utils.ForCausalLMLoss``."""
        # skip the upcast to float since we handle that in the chunking loss
        if shift_labels is None:
            # Shift so that tokens < n predict n
            labels = F.pad(labels, (0, 1), value=ignore_index)
            shift_labels = labels[..., 1:].contiguous()

        # Skip flattening the tokens
        # Enable model parallelism
        shift_labels = shift_labels.to(logits.device)
        loss = chunked_fix_cross_entropy(
            logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    return for_causal_lm_chunked_loss


def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    """Monkeypatch transformers' causal-LM loss with the chunked implementation."""
    import transformers.loss.loss_utils

    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
    # patch both the symbol and the dispatch table so all lookup paths agree
    transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
    transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
        for_causal_lm_chunked_loss
    )
"""
test suite for chunked cross entropy
"""

import pytest
import torch
from torch import nn

from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss


@pytest.fixture
def chunked_fixtures():
    """Random LM head + inputs with a vocab large enough for chunking to matter."""
    model_dim = 512
    vocab_size = 1024 * 256
    seq_len = 2048
    batch_size = 1

    lm_head = nn.Linear(model_dim, vocab_size)
    hidden_state = torch.randn(batch_size, seq_len, model_dim)
    labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
    return lm_head, hidden_state, labels, vocab_size


def test_chunked_forward(chunked_fixtures):  # pylint: disable=redefined-outer-name
    """Chunked causal-LM loss must match an eager CE over identically shifted labels."""
    lm_head, hidden_state, labels, vocab_size = chunked_fixtures
    lm_loss = get_causal_lm_loss()

    logits = lm_head(hidden_state)

    chunked_lm_loss = lm_loss(logits, labels)

    # The chunked loss applies the HF causal shift (pad labels with -100, drop
    # the first position), so the reference must shift identically. Comparing
    # against UNSHIFTED labels — as this test originally did — only passed by
    # accident: a random-init head yields ~log(vocab) loss for any labels.
    shift_labels = nn.functional.pad(labels, (0, 1), value=-100)[..., 1:].contiguous()

    logits_flattened = logits.view(-1, vocab_size)
    labels_flattened = shift_labels.view(-1)

    loss = nn.functional.cross_entropy(
        logits_flattened.float(), labels_flattened, reduction="mean"
    )

    assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)