refactor so we can add test
This commit is contained in:
@@ -71,7 +71,7 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
|
|||||||
return total_loss / total_elements
|
return total_loss / total_elements
|
||||||
|
|
||||||
|
|
||||||
def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||||
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
|
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
|
||||||
loss_fn_ce.compute_cross_entropy = torch.compile(
|
loss_fn_ce.compute_cross_entropy = torch.compile(
|
||||||
loss_fn_ce.compute_cross_entropy, backend="inductor"
|
loss_fn_ce.compute_cross_entropy, backend="inductor"
|
||||||
@@ -79,10 +79,8 @@ def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
|
|||||||
return loss_fn_ce
|
return loss_fn_ce
|
||||||
|
|
||||||
|
|
||||||
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||||
import transformers.loss.loss_utils
|
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
|
||||||
|
|
||||||
loss_fn_ce = build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
|
|
||||||
|
|
||||||
def chunked_fix_cross_entropy(
|
def chunked_fix_cross_entropy(
|
||||||
source,
|
source,
|
||||||
@@ -103,7 +101,7 @@ def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
|
|||||||
def for_causal_lm_chunked_loss(
|
def for_causal_lm_chunked_loss(
|
||||||
logits,
|
logits,
|
||||||
labels,
|
labels,
|
||||||
vocab_size: int, # pylint: disable=unused-argument
|
vocab_size: int = None, # pylint: disable=unused-argument
|
||||||
num_items_in_batch: Optional[int] = None,
|
num_items_in_batch: Optional[int] = None,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
shift_labels: Optional[torch.Tensor] = None,
|
shift_labels: Optional[torch.Tensor] = None,
|
||||||
@@ -123,6 +121,13 @@ def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
|
|||||||
)
|
)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
return for_causal_lm_chunked_loss
|
||||||
|
|
||||||
|
|
||||||
|
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||||
|
import transformers.loss.loss_utils
|
||||||
|
|
||||||
|
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
|
||||||
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
|
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
|
||||||
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
|
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
|
||||||
for_causal_lm_chunked_loss
|
for_causal_lm_chunked_loss
|
||||||
|
|||||||
40
tests/test_chunked_xentropy.py
Normal file
40
tests/test_chunked_xentropy.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
test suite for chunked cross entropy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chunked_fixtures():
|
||||||
|
model_dim = 512
|
||||||
|
vocab_size = 1024 * 256
|
||||||
|
seq_len = 2048
|
||||||
|
batch_size = 1
|
||||||
|
|
||||||
|
lm_head = nn.Linear(model_dim, vocab_size)
|
||||||
|
hidden_state = torch.randn(batch_size, seq_len, model_dim)
|
||||||
|
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
|
||||||
|
return lm_head, hidden_state, labels, vocab_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
|
||||||
|
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
|
||||||
|
lm_loss = get_causal_lm_loss()
|
||||||
|
|
||||||
|
logits = lm_head(hidden_state)
|
||||||
|
|
||||||
|
chunked_lm_loss = lm_loss(logits, labels)
|
||||||
|
|
||||||
|
logits_flattened = logits.view(-1, vocab_size)
|
||||||
|
labels_flattened = labels.view(-1)
|
||||||
|
|
||||||
|
loss = nn.functional.cross_entropy(
|
||||||
|
logits_flattened.float(), labels_flattened, reduction="mean"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)
|
||||||
Reference in New Issue
Block a user