Compare commits
4 Commits
no-seq-len
...
unsloth_mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f9b172c47 | ||
|
|
8671ed5a0c | ||
|
|
538c004080 | ||
|
|
add3b139ed |
168
src/axolotl/monkeypatch/cross_entropy.py
Normal file
168
src/axolotl/monkeypatch/cross_entropy.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Adapted from Unsloth
|
||||
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
MAX_FUSED_SIZE = 65536
|
||||
|
||||
def calculate_settings(n):
    """
    Choose Triton launch settings for a kernel that processes ``n`` elements
    per program.

    Arguments:
        n: number of elements in one row (e.g. the vocabulary size).
    Returns:
        (BLOCK_SIZE, num_warps): BLOCK_SIZE is ``triton.next_power_of_2(n)``;
        num_warps grows with BLOCK_SIZE (4, 8, 16 or 32).
    Raises:
        RuntimeError: if the rounded-up block size exceeds MAX_FUSED_SIZE.
    """
    block_size = triton.next_power_of_2(n)
    # The limit is on the number of elements a single program may cover
    # (MAX_FUSED_SIZE), not on hardware threads per block.
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")

    # (minimum block size, warps) pairs, checked from largest to smallest.
    warp_table = ((32768, 32), (8192, 16), (2048, 8))
    for min_block, warps in warp_table:
        if block_size >= min_block:
            return block_size, warps
    # Small rows fall back to the default of 4 warps.
    return block_size, 4
|
||||
|
||||
@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
                           loss_ptr,
                           lse_ptr,
                           labels_ptr,
                           n_cols,
                           BLOCK_SIZE: tl.constexpr,):
    """
        Per-row cross entropy. One program handles one row of logits.

        Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
        Pi = exp(xi) / sum(exp(xi))
        CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
             = -y [ x - log[sum(exp(x))] ]
             = y * (log[sum(exp(x))] - x)
        If y == 0: CE_i = 0
        If y == 1: CE_i = logsumexp - x

        Writes logsumexp of the row to lse_ptr[row] and the loss to
        loss_ptr[row]. Rows whose label is -100 get a loss of 0.
        The whole row must fit in BLOCK_SIZE (columns >= n_cols are masked).
    """
    # Widen the row index to 64 bits before it is multiplied by the row
    # stride: with a 32-bit index the offset wraps once
    # n_rows * logits_row_stride reaches 2**31.
    row_idx = tl.program_id(0).to(tl.int64)
    logits_ptr += row_idx * logits_row_stride
    loss_ptr   += row_idx
    lse_ptr    += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # The label is a column index within one row, so 32 bits are enough here.
    label_idx = tl.load(labels_ptr).to(tl.int32)
    # Padding lanes load -inf so they contribute exp(-inf) = 0 to the sum.
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
    max_logits = tl.max(logits, 0)
    # Maximum stops overflow
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    # Saved for the backward pass, which recomputes softmax as exp(x - lse).
    tl.store(lse_ptr, lse)

    if label_idx != -100:
        logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
        loss = lse - logits_label
    else:
        # -100 marks an ignored position.
        loss = 0.0
    tl.store(loss_ptr, loss)
|
||||
|
||||
|
||||
@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
                            dloss_ptr,   dloss_row_stride,
                            lse_ptr,
                            labels_ptr,
                            n_cols,
                            BLOCK_SIZE: tl.constexpr,):
    """
        Gradient of the per-row cross entropy with respect to the logits.
        One program handles one row, and the result is written IN PLACE over
        that row of logits.

        CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
        dC/dx = d/dx (y * log[sum(exp(x))] - x * y)

        From https://en.wikipedia.org/wiki/LogSumExp
        d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)

        dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
        dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
        dC/dx = y * exp[x - logsumexp] - d/dx (x * y)

        If y == 0: dC/dx = 0
        If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
        If y == 1 and x != label: dC/dx     = exp[x - logsumexp]
    """
    # Widen the row index to 64 bits before it is multiplied by the row
    # stride: with a 32-bit index the offset wraps once
    # n_rows * logits_row_stride reaches 2**31.
    row_idx = tl.program_id(0).to(tl.int64)
    logits_ptr += row_idx * logits_row_stride
    dloss_ptr  += row_idx * dloss_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # The label is a column index within one row, so 32 bits are enough here.
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        # Ignored rows get a zero upstream gradient, so the whole row is zeroed.
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
    lse = tl.load(lse_ptr + row_idx)
    # softmax(x) rebuilt from the logsumexp stored by the forward kernel.
    probs = tl.exp(logits - lse)

    # Subtract 1 at the label column (the d/dx (x * y) term).
    probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    # Overwrite the logits row with the gradient; padding lanes are masked out.
    tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLoss(torch.autograd.Function):
    """
    Autograd wrapper around the Triton cross entropy kernels.

    forward returns one unreduced float32 loss per row; rows labelled -100
    get a loss of 0. backward overwrites the saved logits tensor with its own
    gradient instead of allocating a new one, so the logits passed to forward
    must not be used again after the backward pass.
    """

    @staticmethod
    def forward(ctx, logits, labels):
        """
        Arguments:
            logits: (n_rows, n_cols) tensor on a CUDA device. The kernel
                indexes columns with unit stride.
            labels: (n_rows,) class indices, -100 for ignored rows.
        Returns:
            losses: (n_rows,) float32 tensor on the same device as logits.
        """
        n_rows, n_cols = logits.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        # Allocate outputs next to the inputs rather than on the current
        # default CUDA device, which may be a different GPU.
        losses    = torch.empty(n_rows, dtype = torch.float32, device = logits.device)
        logsumexp = torch.empty(n_rows, dtype = torch.float32, device = logits.device)

        # Launch on the device that owns the logits.
        with torch.cuda.device(logits.device):
            _cross_entropy_forward[(n_rows,)](
                logits, logits.stride(0),
                losses,
                logsumexp,
                labels,
                n_cols,
                BLOCK_SIZE = BLOCK_SIZE,
                num_warps  = num_warps,
            )

        # Reuse the same launch settings in backward.
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps  = num_warps
        ctx.save_for_backward(logits, logsumexp, labels)
        return losses

    @staticmethod
    def backward(ctx, dlosses):
        """
        Arguments:
            dlosses: (n_rows,) upstream gradient of the per-row losses. Its
                stride is forwarded to the kernel, so it need not be
                contiguous.
        Returns:
            (grad_logits, None): grad_logits is the saved logits tensor,
            overwritten in place with the gradient; labels get no gradient.
        """
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape

        with torch.cuda.device(logits.device):
            _cross_entropy_backward[(n_rows,)](
                logits,   logits.stride(0),
                dlosses, dlosses.stride(0),
                logsumexp,
                labels,
                n_cols,
                BLOCK_SIZE = ctx.BLOCK_SIZE,
                num_warps  = ctx.num_warps,
            )
        # One gradient per forward input: (logits, labels).
        return logits, None
|
||||
|
||||
|
||||
def fast_cross_entropy_loss(logits, labels):
    """
    Mean cross entropy over all positions whose label is not -100.

    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        0-dim tensor: sum of the per-token losses divided by the number of
        labels that are not -100.
    """
    batch, seq_len, vocab = logits.shape
    assert(labels.shape == (batch, seq_len))

    # Flatten to one row per token for the row-wise kernel.
    flat_logits = logits.view(batch*seq_len, vocab)
    flat_labels = labels.view(-1)
    per_token_loss = CrossEntropyLoss.apply(flat_logits, flat_labels)

    # Ignored positions contribute 0 to the sum and are left out of the count.
    is_counted = labels != -100
    n_counted = torch.count_nonzero(is_counted)
    total_loss = per_token_loss.sum()
    return total_loss / n_counted
|
||||
@@ -13,16 +13,20 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralForCausalLM as OriginalMistralForCausalLM,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
@@ -36,6 +40,9 @@ def replace_mistral_attn_with_flash_attn(
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
|
||||
mistral_causallm_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
@@ -641,3 +648,71 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def mistral_causallm_forward(
    self: OriginalMistralForCausalLM,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Replacement forward for `MistralForCausalLM` that computes the loss with
    `fast_cross_entropy_loss`.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Extra positional and keyword arguments are accepted and ignored.

    Returns:
        A `CausalLMOutputWithPast`, or when `return_dict` is false a tuple
        `(loss, logits, *outputs[1:])` with `loss` omitted if no labels were
        given.
    """

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # The logits are used unsliced. The labels are shifted left by one
        # and padded with a trailing -100 column, so the last position of
        # every sequence is ignored by the loss.
        # The padding column is built from `labels` on every call so it
        # always matches their batch size, dtype and device.
        ignored_column = labels.new_full((labels.shape[0], 1), -100)
        shift_labels = torch.hstack((labels[..., 1:], ignored_column))
        shift_labels = shift_labels.to(logits.device)

        # FAST CROSS ENTROPY
        loss = fast_cross_entropy_loss(logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
||||
Reference in New Issue
Block a user