weighted CE losses
This commit is contained in:
@@ -8,7 +8,14 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
|
||||||
|
try:
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||||
|
except ImportError:
|
||||||
|
from flash_attn.flash_attn_interface import (
|
||||||
|
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ class MultipackDistributedDataloader:
|
|||||||
# }
|
# }
|
||||||
# chunked_data.append(chunk)
|
# chunked_data.append(chunk)
|
||||||
# yield self.collate_fn(chunked_data)
|
# yield self.collate_fn(chunked_data)
|
||||||
yield self.collate_fn(concatenated)
|
yield self.collate_fn([concatenated])
|
||||||
len_remaining -= 1
|
len_remaining -= 1
|
||||||
if not len_remaining:
|
if not len_remaining:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -33,6 +33,52 @@ from axolotl.utils.schedulers import (
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def weighted_cross_entropy(
|
||||||
|
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")
|
||||||
|
).sum()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def create_weighted_mask(labels: torch.Tensor):
|
||||||
|
mask = labels != -100
|
||||||
|
|
||||||
|
# Create a tensor to track group ids
|
||||||
|
group_ids = torch.zeros_like(labels).int()
|
||||||
|
curr_group_id = 0
|
||||||
|
|
||||||
|
for i in range(1, len(labels)):
|
||||||
|
if mask[i] and not mask[i - 1]: # switch from masked to unmasked label
|
||||||
|
curr_group_id += 1 # start new group
|
||||||
|
group_ids[i] = (
|
||||||
|
curr_group_id if mask[i] else 0
|
||||||
|
) # assign group id if unmasked label
|
||||||
|
|
||||||
|
# Count only unmasked labels in each group
|
||||||
|
group_counts = torch.bincount(group_ids[mask])
|
||||||
|
|
||||||
|
mask_weights = torch.zeros_like(labels).float()
|
||||||
|
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||||
|
|
||||||
|
return mask_weights
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||||
|
logits = (
|
||||||
|
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||||
|
)
|
||||||
|
if shift_labels:
|
||||||
|
logits = logits[..., :-1, :].contiguous()
|
||||||
|
labels = labels[..., 1:].contiguous()
|
||||||
|
|
||||||
|
weights = create_weighted_mask(labels)
|
||||||
|
return weighted_cross_entropy(logits, labels, weights)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(TrainingArguments):
|
class AxolotlTrainingArguments(TrainingArguments):
|
||||||
"""
|
"""
|
||||||
@@ -137,6 +183,14 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
|
if self.args.sample_packing:
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user