diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 073786882..600c5ad54 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -8,7 +8,14 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
+
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 2f2b0b372..f1ab86e37 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -197,7 +197,7 @@ class MultipackDistributedDataloader:
             #     }
             #     chunked_data.append(chunk)
             #     yield self.collate_fn(chunked_data)
-            yield self.collate_fn(concatenated)
+            yield self.collate_fn([concatenated])
             len_remaining -= 1
             if not len_remaining:
                 return
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index e8a14df1f..d4245563e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -33,6 +33,59 @@ from axolotl.utils.schedulers import (
 LOG = logging.getLogger("axolotl")
 
 
+@torch.jit.script
+def weighted_cross_entropy(
+    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
+):
+    # Flatten to (N, vocab) logits and (N,) targets: cross_entropy treats
+    # dim 1 as the class dim, so the (batch, seq, vocab) layout is invalid.
+    logits = logits.view(-1, logits.size(-1))
+    labels = labels.view(-1)
+    weights = weights.view(-1)
+
+    # Per-token losses, weighted and summed.
+    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+    return (weights * losses).sum()
+
+
+@torch.jit.script
+def create_weighted_mask(labels: torch.Tensor):
+    mask = labels != -100
+
+    # Create a tensor to track group ids
+    group_ids = torch.zeros_like(labels).int()
+    curr_group_id = 0
+
+    for i in range(1, len(labels)):
+        if mask[i] and not mask[i - 1]:  # switch from masked to unmasked label
+            curr_group_id += 1  # start new group
+        group_ids[i] = (
+            curr_group_id if mask[i] else 0
+        )  # assign group id if unmasked label
+
+    # Count only unmasked labels in each group
+    group_counts = torch.bincount(group_ids[mask])
+
+    mask_weights = torch.zeros_like(labels).float()
+    mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
+
+    return mask_weights
+
+
+# NOTE: not torch.jit.script-able — TorchScript would infer `model_output` as
+# Tensor, statically drop the isinstance(dict) branch, and break on ModelOutput.
+def trainer_weighted_loss(model_output, labels, shift_labels=True):
+    logits = (
+        model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+    )
+    if shift_labels:
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+
+    weights = create_weighted_mask(labels)
+    return weighted_cross_entropy(logits, labels, weights)
+
+
 @dataclass
 class AxolotlTrainingArguments(TrainingArguments):
     """
@@ -137,6 +190,14 @@ class AxolotlTrainer(Trainer):
         )
         return super().get_eval_dataloader(eval_dataset)
 
+    def compute_loss(self, model, inputs, return_outputs=False):
+        if self.args.sample_packing:
+            labels = inputs.pop("labels")
+            outputs = model(**inputs)
+            loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
+            return (loss, outputs) if return_outputs else loss
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
 
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """