From 98c9bc69de3a8c846cb7d48536469b73017bc160 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 2 Aug 2023 23:20:19 -0400 Subject: [PATCH] seq_len_multiple for packing --- src/axolotl/utils/dataloader.py | 16 ++++++---- src/axolotl/utils/trainer.py | 54 ++++++++++++++++++++------------- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index b12ea338c..394afe0e5 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -128,6 +128,7 @@ class MultipackDistributedDataloader: batch_size: int = 1, sampler: Union[Sampler, DistributedSampler] = None, packing_efficiency_estimate: float = 1.0, + seq_len_multiple: int = 1, ): # Dataset self.dataset = dataset @@ -135,9 +136,10 @@ class MultipackDistributedDataloader: [len(sample["input_ids"]) for sample in self.dataset] ) assert isinstance(self.lengths, np.ndarray) - + assert batch_size % seq_len_multiple == 0 self.sampler = sampler self.batch_size = batch_size + self.seq_len_multiple = seq_len_multiple self.seq_max_length = seq_max_length self.batch_max_length = batch_size * seq_max_length self.collate_fn = collate_fn @@ -148,7 +150,7 @@ class MultipackDistributedDataloader: # statistics self.eff_total_used = 0 self.eff_total_slots = 0 - self.packing_efficiency_estimate = packing_efficiency_estimate + self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 def generate_batches(self, set_stats=False): if self.sampler: @@ -164,7 +166,7 @@ class MultipackDistributedDataloader: lengths_cumsum=lengths_cumsum, rank=self.rank, # c=self.batch_max_length, - c=self.seq_max_length, + c=self.seq_max_length * self.seq_len_multiple, n=self.num_replicas, ) @@ -181,18 +183,20 @@ class MultipackDistributedDataloader: all_batches, _ = self.generate_batches(set_stats=True) features = self.dataset.features.keys() len_remaining = self._len_est() - for batches in chunk(all_batches, self.batch_size): + for batches in chunk(all_batches, self.batch_size / self.seq_len_multiple): chunked_data = [] + attn_mask_cum_idx = 0 for batch in batches: concatenated = {} batched_data = [self.dataset[batch_idx] for batch_idx in batch] for feature in features: if feature == "attention_mask": arrays = [ - (idx + 1) * np.array(item[feature]) + (attn_mask_cum_idx + idx + 1) * np.array(item[feature]) for idx, item in enumerate(batched_data) if feature in item ] + attn_mask_cum_idx += len(batched_data) concatenated[feature] = np.concatenate(arrays) else: arrays = [ @@ -216,7 +220,7 @@ class MultipackDistributedDataloader: # } # chunked_data.append(chunk) # yield self.collate_fn(chunked_data) - yield self.collate_fn([chunked_data]) + yield self.collate_fn(chunked_data) len_remaining -= 1 if not len_remaining: return diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2dedfab31..1cc6ba57b 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -53,26 +53,34 @@ def weighted_cross_entropy( @torch.jit.script def create_weighted_mask(labels: torch.Tensor): - mask = labels != -100 + # Check if the tensor is 2D. If not, unsqueeze it to make it 2D + if len(labels.shape) == 1: + labels = labels.unsqueeze(0) - # Create a tensor to track group ids - group_ids = torch.zeros_like(labels).int() - curr_group_id = 0 + weights = torch.zeros_like(labels).float() + for i in range(labels.shape[0]): + mask = labels[i] != -100 - for i in range(1, len(labels)): - if mask[i] and not mask[i - 1]: # switch from masked to unmasked label - curr_group_id += 1 # start new group - group_ids[i] = ( - curr_group_id if mask[i] else 0 - ) # assign group id if unmasked label + # Create a tensor to track group ids + group_ids = torch.zeros_like(labels[i]).int() + curr_group_id = 0 - # Count only unmasked labels in each group - group_counts = torch.bincount(group_ids[mask]) + for j in range(1, len(labels[i])): + if mask[j] and not mask[j - 1]: # switch from masked to unmasked label + curr_group_id += 1 # start new group + group_ids[j] = ( + curr_group_id if mask[j] else 0 + ) # assign group id if unmasked label - mask_weights = torch.zeros_like(labels).float() - mask_weights[mask] = 1.0 / group_counts[group_ids[mask]] + # Count only unmasked labels in each group + group_counts = torch.bincount(group_ids[mask]) - return mask_weights + mask_weights = torch.zeros_like(labels[i]).float() + mask_weights[mask] = 1.0 / group_counts[group_ids[mask]] + + weights[i] = mask_weights + + return weights.squeeze() # squeeze the output to match the input dimension def trainer_weighted_loss(model_output, labels, shift_labels=True): @@ -168,6 +176,7 @@ class AxolotlTrainer(Trainer): collate_fn=self.data_collator, sampler=train_sampler, packing_efficiency_estimate=self.args.sample_packing_efficiency, + seq_len_multiple=2, ) ) return super().get_train_dataloader() @@ -187,16 +196,18 @@ class AxolotlTrainer(Trainer): seq_max_length=self.args.max_seq_length, collate_fn=self.data_collator, sampler=eval_sampler, + packing_efficiency_estimate=self.args.sample_packing_efficiency, + seq_len_multiple=2, ) ) return super().get_eval_dataloader(eval_dataset) def compute_loss(self, model, inputs, return_outputs=False): - if self.args.sample_packing: - labels = inputs.pop("labels") - outputs = model(**inputs) - loss = trainer_weighted_loss(outputs, labels, shift_labels=True) - return (loss, outputs) if return_outputs else loss + # if self.args.sample_packing: + # labels = inputs.pop("labels") + # outputs = model(**inputs) + # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) + # return (loss, outputs) if return_outputs else loss return super().compute_loss(model, inputs, return_outputs=return_outputs) @@ -262,7 +273,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): * total_num_tokens / cfg.sample_packing_eff_est / 2048 - / cfg.batch_size + // cfg.batch_size ) - 1 ) @@ -284,6 +295,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): ), sampler=sampler, packing_efficiency_estimate=cfg.sample_packing_eff_est, + seq_len_multiple=2, ) data_loader_len = len(data_loader) LOG.info(f"data_loader_len: {data_loader_len}")