From 52cab2aa5b5bfd181235c28ba8662012c2696320 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 3 May 2025 21:47:45 -0400
Subject: [PATCH] refactor so we can add test

---
 src/axolotl/monkeypatch/loss/chunked.py | 17 +++++++----
 tests/test_chunked_xentropy.py          | 40 +++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_chunked_xentropy.py

diff --git a/src/axolotl/monkeypatch/loss/chunked.py b/src/axolotl/monkeypatch/loss/chunked.py
index d1f9b32b9..0a9d0de82 100644
--- a/src/axolotl/monkeypatch/loss/chunked.py
+++ b/src/axolotl/monkeypatch/loss/chunked.py
@@ -71,7 +71,7 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
         return total_loss / total_elements
 
 
-def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
+def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
     loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
     loss_fn_ce.compute_cross_entropy = torch.compile(
         loss_fn_ce.compute_cross_entropy, backend="inductor"
@@ -79,10 +79,8 @@ def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
     return loss_fn_ce
 
 
-def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
-    import transformers.loss.loss_utils
-
-    loss_fn_ce = build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
+def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
+    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
 
     def chunked_fix_cross_entropy(
         source,
@@ -103,7 +101,7 @@
     def for_causal_lm_chunked_loss(
         logits,
         labels,
-        vocab_size: int,  # pylint: disable=unused-argument
+        vocab_size: int = None,  # pylint: disable=unused-argument
         num_items_in_batch: Optional[int] = None,
         ignore_index: int = -100,
         shift_labels: Optional[torch.Tensor] = None,
@@ -123,6 +121,13 @@ def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
         )
         return loss
 
+    return for_causal_lm_chunked_loss
+
+
+def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
+    import transformers.loss.loss_utils
+
+    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
     transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
     transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
         for_causal_lm_chunked_loss
diff --git a/tests/test_chunked_xentropy.py b/tests/test_chunked_xentropy.py
new file mode 100644
index 000000000..3e439f0a3
--- /dev/null
+++ b/tests/test_chunked_xentropy.py
@@ -0,0 +1,40 @@
+"""
+test suite for chunked cross entropy
+"""
+
+import pytest
+import torch
+from torch import nn
+
+from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
+
+
+@pytest.fixture
+def chunked_fixtures():
+    model_dim = 512
+    vocab_size = 1024 * 256
+    seq_len = 2048
+    batch_size = 1
+
+    lm_head = nn.Linear(model_dim, vocab_size)
+    hidden_state = torch.randn(batch_size, seq_len, model_dim)
+    labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
+    return lm_head, hidden_state, labels, vocab_size
+
+
+def test_chunked_forward(chunked_fixtures):  # pylint: disable=redefined-outer-name
+    lm_head, hidden_state, labels, vocab_size = chunked_fixtures
+    lm_loss = get_causal_lm_loss()
+
+    logits = lm_head(hidden_state)
+
+    chunked_lm_loss = lm_loss(logits, labels)
+
+    logits_flattened = logits.view(-1, vocab_size)
+    labels_flattened = labels.view(-1)
+
+    loss = nn.functional.cross_entropy(
+        logits_flattened.float(), labels_flattened, reduction="mean"
+    )
+
+    assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)