add position_ids back

Wing Lian
2023-07-18 01:50:41 -04:00
parent 3aba4c5d7c
commit 762f1b08db
2 changed files with 21 additions and 2 deletions
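What the commit restores, in short: when sample packing concatenates several tokenized examples into a single fixed-length row, each example needs position ids that restart at zero, so positional embeddings treat the packed sub-sequences independently rather than as one long sequence. The snippet below is illustrative only (not code from this repository); it just shows the torch.arange/torch.cat pattern the hunks below rely on:

```python
import torch

# two already-tokenized examples that will be packed into one row
first = torch.tensor([5, 6, 7])
second = torch.tensor([8, 9])

# each example gets positions starting at 0, mirroring
# torch.arange(len(input_ids)) in the hunks below
packed_position_ids = torch.cat(
    [torch.arange(len(first)), torch.arange(len(second))], dim=-1
)
print(packed_position_ids)  # tensor([0, 1, 2, 0, 1]): positions restart at the boundary
```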


@@ -81,6 +81,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
@@ -112,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
                         attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                             : self.seq_length
                         ]
+                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
+                            : self.seq_length
+                        ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                         if labels.size() == input_ids.size() and (
                             attention_mask.size() == input_ids.size()
@@ -120,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
@@ -129,6 +134,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
@@ -155,8 +161,12 @@ class ConstantLengthDataset(IterableDataset):
                         labels_with_concat = torch.tensor(
                             labels, dtype=self.tokens_dtype
                         )
+                        position_ids = torch.arange(
+                            len(input_ids), dtype=self.tokens_dtype
+                        )
                         buffer["input_ids"].append(input_ids_with_concat)
                         buffer["attention_mask"].append(attention_mask_with_concat)
                         buffer["labels"].append(labels_with_concat)
+                        buffer["position_ids"].append(position_ids)
                         buffer_len += len(input_ids)
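Taken together, the ConstantLengthDataset hunks mean that when the buffer is flushed, position_ids is concatenated and truncated with exactly the same rule as input_ids, so the yielded tensors stay aligned. A simplified, self-contained sketch of that flush step (names follow the diff; the surrounding class, warning branch, and buffer bookkeeping are omitted):

```python
import torch

seq_length = 6
buffer = {
    "input_ids": [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10, 11])],
    "attention_mask": [torch.ones(4, dtype=torch.long), torch.ones(3, dtype=torch.long)],
    "labels": [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10, 11])],
    "position_ids": [torch.arange(4), torch.arange(3)],
}

input_ids = torch.cat(buffer["input_ids"], dim=-1)[:seq_length]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[:seq_length]
labels = torch.cat(buffer["labels"], dim=-1)[:seq_length]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[:seq_length]

# every field is truncated identically, so the packed row stays aligned:
# position_ids -> [0, 1, 2, 3, 0, 1]
if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
    packed = {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
```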


@@ -22,6 +22,7 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
+from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
@@ -89,6 +90,7 @@ class AxolotlTrainer(Trainer):
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
             return MultipackDistributedDataloader(
                 self.train_dataset,
                 batch_size=self._train_batch_size,
@@ -142,8 +144,15 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
         return self.lr_scheduler
 
 
+def add_position_ids(sample):
+    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
+    return sample
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.sample_packing:
+        train_dataset = train_dataset.map(add_position_ids)
+        eval_dataset = eval_dataset.map(add_position_ids)
         sampler = DistributedSampler(
             train_dataset,
             num_replicas=1,
@@ -154,7 +163,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             train_dataset,
             batch_size=cfg.micro_batch_size,
             seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
-            collate_fn=transformers.DataCollatorForSeq2Seq(
+            collate_fn=DataCollatorForSeq2Seq(
                 tokenizer,
                 return_tensors="pt",
                 padding="longest",
@@ -412,7 +421,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
-        data_collator=transformers.DataCollatorForSeq2Seq(
+        data_collator=DataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             **data_collator_kwargs,
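The two call sites above switch from transformers.DataCollatorForSeq2Seq to the repo's own DataCollatorForSeq2Seq from axolotl.utils.collators, presumably so the extra position_ids key survives batching and is padded alongside input_ids, which the stock collator does not handle. The class below is a hypothetical, simplified stand-in for such a collator; it is not the implementation in axolotl.utils.collators:

```python
from typing import Dict, List

import torch


class PadWithPositionIdsCollator:
    """Illustrative only: pads position_ids together with the usual keys."""

    def __init__(self, pad_token_id: int = 0, label_pad_token_id: int = -100):
        self.pad_values = {
            "input_ids": pad_token_id,
            "attention_mask": 0,
            "labels": label_pad_token_id,
            "position_ids": 0,  # padded slots are masked out via attention_mask anyway
        }

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        # assumes each feature holds plain Python lists, as a non-torch-formatted dataset yields
        max_len = max(len(f["input_ids"]) for f in features)
        return {
            key: torch.tensor(
                [f[key] + [pad] * (max_len - len(f[key])) for f in features]
            )
            for key, pad in self.pad_values.items()
            if key in features[0]
        }


batch = PadWithPositionIdsCollator()(
    [
        {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [5, 6, 7], "position_ids": [0, 1, 2]},
        {"input_ids": [8, 9], "attention_mask": [1, 1], "labels": [8, 9], "position_ids": [0, 1]},
    ]
)
# batch["position_ids"] -> tensor([[0, 1, 2], [0, 1, 0]])
```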