refactor so we can add test
This commit is contained in:
@@ -71,7 +71,7 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
|
||||
return total_loss / total_elements
|
||||
|
||||
|
||||
def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
|
||||
loss_fn_ce.compute_cross_entropy = torch.compile(
|
||||
loss_fn_ce.compute_cross_entropy, backend="inductor"
|
||||
@@ -79,10 +79,8 @@ def build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
|
||||
return loss_fn_ce
|
||||
|
||||
|
||||
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
import transformers.loss.loss_utils
|
||||
|
||||
loss_fn_ce = build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
|
||||
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
|
||||
|
||||
def chunked_fix_cross_entropy(
|
||||
source,
|
||||
@@ -103,7 +101,7 @@ def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
|
||||
def for_causal_lm_chunked_loss(
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int, # pylint: disable=unused-argument
|
||||
vocab_size: int = None, # pylint: disable=unused-argument
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels: Optional[torch.Tensor] = None,
|
||||
@@ -123,6 +121,13 @@ def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -10
|
||||
)
|
||||
return loss
|
||||
|
||||
return for_causal_lm_chunked_loss
|
||||
|
||||
|
||||
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
import transformers.loss.loss_utils
|
||||
|
||||
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
|
||||
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
|
||||
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
|
||||
for_causal_lm_chunked_loss
|
||||
|
||||
Reference in New Issue
Block a user