weighted CE losses
This commit is contained in:
@@ -8,7 +8,14 @@ import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
except ImportError:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ class MultipackDistributedDataloader:
|
||||
# }
|
||||
# chunked_data.append(chunk)
|
||||
# yield self.collate_fn(chunked_data)
|
||||
yield self.collate_fn(concatenated)
|
||||
yield self.collate_fn([concatenated])
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
|
||||
@@ -33,6 +33,52 @@ from axolotl.utils.schedulers import (
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def weighted_cross_entropy(
|
||||
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
||||
):
|
||||
return (
|
||||
weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")
|
||||
).sum()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def create_weighted_mask(labels: torch.Tensor):
|
||||
mask = labels != -100
|
||||
|
||||
# Create a tensor to track group ids
|
||||
group_ids = torch.zeros_like(labels).int()
|
||||
curr_group_id = 0
|
||||
|
||||
for i in range(1, len(labels)):
|
||||
if mask[i] and not mask[i - 1]: # switch from masked to unmasked label
|
||||
curr_group_id += 1 # start new group
|
||||
group_ids[i] = (
|
||||
curr_group_id if mask[i] else 0
|
||||
) # assign group id if unmasked label
|
||||
|
||||
# Count only unmasked labels in each group
|
||||
group_counts = torch.bincount(group_ids[mask])
|
||||
|
||||
mask_weights = torch.zeros_like(labels).float()
|
||||
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||
|
||||
return mask_weights
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
logits = (
|
||||
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||
)
|
||||
if shift_labels:
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
|
||||
weights = create_weighted_mask(labels)
|
||||
return weighted_cross_entropy(logits, labels, weights)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
@@ -137,6 +183,14 @@ class AxolotlTrainer(Trainer):
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
if self.args.sample_packing:
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user