weighted CE losses

Wing Lian
2023-08-02 15:57:00 -04:00
parent 83f7362480
commit c6cc54c7d9
3 changed files with 63 additions and 2 deletions


@@ -8,7 +8,14 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+try:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    # older flash-attn releases expose the same kernel under the pre-rename
+    # "unpadded" name; alias it so the rest of the module is version-agnostic
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

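For context: flash-attn 2.x renamed its unpadded attention entry points, so the try/except keeps this monkeypatch importable on either side of the rename while the rest of the module targets the new name unconditionally. A rough usage sketch of the aliased kernel follows; the argument order matches the flash-attn 2.x varlen qkvpacked API, but treat the exact signature as an assumption on older releases:

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

# qkv packed as (total_tokens, 3, n_heads, head_dim); cu_seqlens holds the
# cumulative boundaries of each unpadded sequence in the flattened stream
qkv = torch.randn(10, 3, 32, 128, dtype=torch.float16, device="cuda")
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
max_seqlen = 6  # longest single sequence in the pack

out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
print(out.shape)  # torch.Size([10, 32, 128]): one row per unpadded token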

@@ -197,7 +197,7 @@ class MultipackDistributedDataloader:
                 #     }
                 #     chunked_data.append(chunk)
                 # yield self.collate_fn(chunked_data)
-                yield self.collate_fn(concatenated)
+                yield self.collate_fn([concatenated])
                 len_remaining -= 1
                 if not len_remaining:
                     return

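The added brackets matter because HF-style collators expect a list of feature dicts, one per example; under multipack, the entire concatenated stream is a single example, so it must be wrapped to form a batch of size one. A minimal sketch of that shape contract, using transformers.default_data_collator as a stand-in for whatever collate_fn is configured:

from transformers import default_data_collator

# one packed example: several samples already concatenated into one stream
concatenated = {
    "input_ids": [1, 2, 3, 4, 5, 6],
    "attention_mask": [1, 1, 1, 1, 1, 1],
    "labels": [-100, -100, 3, 4, -100, 6],
}

# collators iterate over examples, so the packed dict is wrapped in a list
batch = default_data_collator([concatenated])
print(batch["input_ids"].shape)  # torch.Size([1, 6]) -- a batch of one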

@@ -33,6 +33,52 @@ from axolotl.utils.schedulers import (
 LOG = logging.getLogger("axolotl")


+@torch.jit.script
+def weighted_cross_entropy(
+    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
+):
+    # flatten to [batch * seq_len, vocab] / [batch * seq_len] so that
+    # cross_entropy sees the class dimension where it expects it
+    logits = logits.view(-1, logits.size(-1))
+    labels = labels.view(-1)
+    weights = weights.view(-1)
+
+    # per-token losses, scaled by the per-token weights, then summed
+    return (
+        weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+    ).sum()
+
+
+@torch.jit.script
+def create_weighted_mask(labels: torch.Tensor):
+    # check if the tensor is 2D; if not, unsqueeze it to make it 2D
+    if len(labels.shape) == 1:
+        labels = labels.unsqueeze(0)
+
+    weights = torch.zeros_like(labels).float()
+    for i in range(labels.shape[0]):
+        mask = labels[i] != -100
+
+        # Create a tensor to track group ids
+        group_ids = torch.zeros_like(labels[i]).int()
+        curr_group_id = 0
+
+        for j in range(1, len(labels[i])):
+            if mask[j] and not mask[j - 1]:  # switch from masked to unmasked label
+                curr_group_id += 1  # start new group
+            group_ids[j] = (
+                curr_group_id if mask[j] else 0
+            )  # assign group id if unmasked label
+
+        # Count only unmasked labels in each group
+        group_counts = torch.bincount(group_ids[mask])
+
+        # tokens in the same group share a total weight of 1.0
+        mask_weights = torch.zeros_like(labels[i]).float()
+        mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
+
+        weights[i] = mask_weights
+
+    return weights.squeeze()  # squeeze back to match the input dimension
+
+
+@torch.jit.script
+def trainer_weighted_loss(model_output, labels, shift_labels=True):
+    logits = (
+        model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+    )
+    if shift_labels:
+        # align logits[t] with labels[t + 1] for next-token prediction
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+
+    weights = create_weighted_mask(labels)
+    return weighted_cross_entropy(logits, labels, weights)
+
+
 @dataclass
 class AxolotlTrainingArguments(TrainingArguments):
     """
@@ -137,6 +183,14 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_eval_dataloader(eval_dataset)

+    def compute_loss(self, model, inputs, return_outputs=False):
+        # with sample packing, weight each packed sample's tokens so every
+        # sample contributes equally to the loss regardless of its length
+        if self.args.sample_packing:
+            labels = inputs.pop("labels")
+            outputs = model(**inputs)
+            loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
+            return (loss, outputs) if return_outputs else loss
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """