From 762f1b08db051fa64a5fa1696cf7a6d7e4eb9a0f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 18 Jul 2023 01:50:41 -0400
Subject: [PATCH] add position_ids back

---
 src/axolotl/datasets.py      | 10 ++++++++++
 src/axolotl/utils/trainer.py | 13 +++++++++++--
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index bc137d238..4376fb18a 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -81,6 +81,7 @@ class ConstantLengthDataset(IterableDataset):
             "input_ids": [],
             "attention_mask": [],
             "labels": [],
+            "position_ids": [],
         }
         buffer_len = 0
         for dataset in self.datasets:
@@ -112,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
                         attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                             : self.seq_length
                         ]
+                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
+                            : self.seq_length
+                        ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                         if labels.size() == input_ids.size() and (
                             attention_mask.size() == input_ids.size()
@@ -120,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
                                 "input_ids": input_ids,
                                 "labels": labels,
                                 "attention_mask": attention_mask,
+                                "position_ids": position_ids,
                             }
                         else:
                             LOG.warning(
@@ -129,6 +134,7 @@ class ConstantLengthDataset(IterableDataset):
                         "input_ids": [],
                         "attention_mask": [],
                         "labels": [],
+                        "position_ids": [],
                     }
                     buffer_len = 0
                     idx = 1
@@ -155,8 +161,12 @@ class ConstantLengthDataset(IterableDataset):
                     labels_with_concat = torch.tensor(
                         labels, dtype=self.tokens_dtype
                     )
+                    position_ids = torch.arange(
+                        len(input_ids), dtype=self.tokens_dtype
+                    )
                     buffer["input_ids"].append(input_ids_with_concat)
                     buffer["attention_mask"].append(attention_mask_with_concat)
                     buffer["labels"].append(labels_with_concat)
+                    buffer["position_ids"].append(position_ids)
                     buffer_len += len(input_ids)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 38d5f0e3b..0b6f0a92a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -22,6 +22,7 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
+from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
@@ -89,6 +90,7 @@ class AxolotlTrainer(Trainer):
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
+
             return MultipackDistributedDataloader(
                 self.train_dataset,
                 batch_size=self._train_batch_size,
@@ -142,8 +144,15 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
         return self.lr_scheduler


+def add_position_ids(sample):
+    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
+    return sample
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.sample_packing:
+        train_dataset = train_dataset.map(add_position_ids)
+        eval_dataset = eval_dataset.map(add_position_ids)
         sampler = DistributedSampler(
             train_dataset,
             num_replicas=1,
@@ -154,7 +163,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             train_dataset,
             batch_size=cfg.micro_batch_size,
             seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
-            collate_fn=transformers.DataCollatorForSeq2Seq(
+            collate_fn=DataCollatorForSeq2Seq(
                 tokenizer,
                 return_tensors="pt",
                 padding="longest",
@@ -412,7 +421,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
-        data_collator=transformers.DataCollatorForSeq2Seq(
+        data_collator=DataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             **data_collator_kwargs,
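
Note (not part of the patch): a minimal sketch of the behaviour the datasets.py hunks above reintroduce. Each example contributes position ids that restart at zero, so a packed sequence carries per-example positions rather than one continuous 0..seq_length-1 range. The seq_length value and token tensors below are illustrative assumptions; dtype handling (self.tokens_dtype) and the attention_mask/labels buffers are omitted for brevity.

# sketch.py -- toy illustration of per-example position_ids under packing
import torch

seq_length = 8  # assumed packing length for this example

# two short "tokenized" examples, as they would arrive from the dataset
examples = [
    {"input_ids": torch.tensor([101, 7, 8, 102])},
    {"input_ids": torch.tensor([101, 9, 102])},
]

buffer = {"input_ids": [], "position_ids": []}
for example in examples:
    input_ids = example["input_ids"]
    buffer["input_ids"].append(input_ids)
    # positions restart at 0 for every example, as in the patched buffer logic
    buffer["position_ids"].append(torch.arange(len(input_ids)))

packed_input_ids = torch.cat(buffer["input_ids"], dim=-1)[:seq_length]
packed_position_ids = torch.cat(buffer["position_ids"], dim=-1)[:seq_length]

print(packed_input_ids)     # tensor([101,   7,   8, 102, 101,   9, 102])
print(packed_position_ids)  # tensor([0, 1, 2, 3, 0, 1, 2])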