add position_ids back

This commit is contained in:
Wing Lian
2023-07-18 01:50:41 -04:00
parent 3aba4c5d7c
commit 762f1b08db
2 changed files with 21 additions and 2 deletions

View File

@@ -81,6 +81,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [], "input_ids": [],
"attention_mask": [], "attention_mask": [],
"labels": [], "labels": [],
"position_ids": [],
} }
buffer_len = 0 buffer_len = 0
for dataset in self.datasets: for dataset in self.datasets:
@@ -112,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length : self.seq_length
] ]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and ( if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size() attention_mask.size() == input_ids.size()
@@ -120,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids, "input_ids": input_ids,
"labels": labels, "labels": labels,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"position_ids": position_ids,
} }
else: else:
LOG.warning( LOG.warning(
@@ -129,6 +134,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [], "input_ids": [],
"attention_mask": [], "attention_mask": [],
"labels": [], "labels": [],
"position_ids": [],
} }
buffer_len = 0 buffer_len = 0
idx = 1 idx = 1
@@ -155,8 +161,12 @@ class ConstantLengthDataset(IterableDataset):
labels_with_concat = torch.tensor( labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype labels, dtype=self.tokens_dtype
) )
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat) buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat) buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat) buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids) buffer_len += len(input_ids)

View File

@@ -22,6 +22,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
InterpolatingLogScheduler, InterpolatingLogScheduler,
@@ -89,6 +90,7 @@ class AxolotlTrainer(Trainer):
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing: if self.args.sample_packing:
train_sampler = self._get_train_sampler() train_sampler = self._get_train_sampler()
return MultipackDistributedDataloader( return MultipackDistributedDataloader(
self.train_dataset, self.train_dataset,
batch_size=self._train_batch_size, batch_size=self._train_batch_size,
@@ -142,8 +144,15 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
return self.lr_scheduler return self.lr_scheduler
def add_position_ids(sample):
    """Attach a ``position_ids`` entry to a single dataset sample.

    The positions are the consecutive integers ``0 .. len(input_ids) - 1``,
    built with ``torch.arange``. The sample is modified in place and the same
    object is returned, which is the shape ``Dataset.map`` callbacks use
    (``setup_trainer`` applies this to the train and eval datasets when
    sample packing is enabled).

    Args:
        sample: Mapping that holds an ``"input_ids"`` sequence supporting
            ``len()``.

    Returns:
        The same ``sample`` object, now carrying ``"position_ids"``.
    """
    # One position per token, restarting at 0 for every sample.
    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
    return sample
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.sample_packing: if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids)
eval_dataset = eval_dataset.map(add_position_ids)
sampler = DistributedSampler( sampler = DistributedSampler(
train_dataset, train_dataset,
num_replicas=1, num_replicas=1,
@@ -154,7 +163,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset, train_dataset,
batch_size=cfg.micro_batch_size, batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len, seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
collate_fn=transformers.DataCollatorForSeq2Seq( collate_fn=DataCollatorForSeq2Seq(
tokenizer, tokenizer,
return_tensors="pt", return_tensors="pt",
padding="longest", padding="longest",
@@ -412,7 +421,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
args=training_args, args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq( data_collator=DataCollatorForSeq2Seq(
tokenizer, tokenizer,
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,