Files
axolotl/tests/test_chunked_xentropy.py
Wing Lian 12c826816d chunked cross entropy loss (#2625)
* chunked cross entropy loss

* refactor so we can add test

* use relative import

* update schema description
2025-06-23 23:08:46 -04:00

41 lines
1.1 KiB
Python

"""
Test suite for the chunked cross-entropy loss.
"""
import pytest
import torch
from torch import nn
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
@pytest.fixture
def chunked_fixtures():
    """
    Build the random inputs used by the chunked cross entropy test.

    Returns:
        A tuple ``(lm_head, hidden_state, labels, vocab_size)`` where
        ``lm_head`` is an ``nn.Linear`` projecting the model dimension (512)
        onto the vocabulary (1024 * 256 entries), ``hidden_state`` is a random
        tensor of shape ``(1, 2048, 512)``, ``labels`` holds random integer
        class ids in ``[0, vocab_size)`` with shape ``(1, 2048)``, and
        ``vocab_size`` is the vocabulary size itself.
    """
    batch, seq, dim = 1, 2048, 512
    vocab_size = 1024 * 256

    # Objects are created in the order head -> hidden state -> labels so the
    # global torch RNG is consumed in a fixed sequence.
    head = nn.Linear(dim, vocab_size)
    hidden = torch.randn(batch, seq, dim)
    targets = torch.randint(low=0, high=vocab_size, size=(batch, seq))
    return head, hidden, targets, vocab_size
def test_chunked_forward(chunked_fixtures):  # pylint: disable=redefined-outer-name
    """
    Compare the chunked causal LM loss with plain mean cross entropy.

    The logits produced by the fixture's linear head are passed, together with
    the labels, to the loss callable returned by ``get_causal_lm_loss()``. The
    result must match ``nn.functional.cross_entropy`` evaluated on the same
    logits and labels (flattened) within ``atol=1e-2`` / ``rtol=1e-2``.
    """
    head, hidden, targets, vocab = chunked_fixtures

    chunked_loss_fn = get_causal_lm_loss()
    logits = head(hidden)
    actual = chunked_loss_fn(logits, targets)

    # Reference value: logits flattened to (N, vocab) and cast with .float(),
    # labels flattened to (N,), averaged over all positions.
    expected = nn.functional.cross_entropy(
        logits.view(-1, vocab).float(),
        targets.view(-1),
        reduction="mean",
    )

    assert torch.allclose(actual, expected, atol=1e-2, rtol=1e-2)