add position_ids back
This commit is contained in:
@@ -81,6 +81,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"input_ids": [],
|
"input_ids": [],
|
||||||
"attention_mask": [],
|
"attention_mask": [],
|
||||||
"labels": [],
|
"labels": [],
|
||||||
|
"position_ids": [],
|
||||||
}
|
}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
@@ -112,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||||
: self.seq_length
|
: self.seq_length
|
||||||
]
|
]
|
||||||
|
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
|
||||||
|
: self.seq_length
|
||||||
|
]
|
||||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||||
if labels.size() == input_ids.size() and (
|
if labels.size() == input_ids.size() and (
|
||||||
attention_mask.size() == input_ids.size()
|
attention_mask.size() == input_ids.size()
|
||||||
@@ -120,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"position_ids": position_ids,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -129,6 +134,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"input_ids": [],
|
"input_ids": [],
|
||||||
"attention_mask": [],
|
"attention_mask": [],
|
||||||
"labels": [],
|
"labels": [],
|
||||||
|
"position_ids": [],
|
||||||
}
|
}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
idx = 1
|
idx = 1
|
||||||
@@ -155,8 +161,12 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
labels_with_concat = torch.tensor(
|
labels_with_concat = torch.tensor(
|
||||||
labels, dtype=self.tokens_dtype
|
labels, dtype=self.tokens_dtype
|
||||||
)
|
)
|
||||||
|
position_ids = torch.arange(
|
||||||
|
len(input_ids), dtype=self.tokens_dtype
|
||||||
|
)
|
||||||
|
|
||||||
buffer["input_ids"].append(input_ids_with_concat)
|
buffer["input_ids"].append(input_ids_with_concat)
|
||||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||||
buffer["labels"].append(labels_with_concat)
|
buffer["labels"].append(labels_with_concat)
|
||||||
|
buffer["position_ids"].append(position_ids)
|
||||||
buffer_len += len(input_ids)
|
buffer_len += len(input_ids)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
InterpolatingLogScheduler,
|
InterpolatingLogScheduler,
|
||||||
@@ -89,6 +90,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
train_sampler = self._get_train_sampler()
|
train_sampler = self._get_train_sampler()
|
||||||
|
|
||||||
return MultipackDistributedDataloader(
|
return MultipackDistributedDataloader(
|
||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
batch_size=self._train_batch_size,
|
batch_size=self._train_batch_size,
|
||||||
@@ -142,8 +144,15 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def add_position_ids(sample):
|
||||||
|
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
|
train_dataset = train_dataset.map(add_position_ids)
|
||||||
|
eval_dataset = eval_dataset.map(add_position_ids)
|
||||||
sampler = DistributedSampler(
|
sampler = DistributedSampler(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
num_replicas=1,
|
num_replicas=1,
|
||||||
@@ -154,7 +163,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=cfg.micro_batch_size,
|
batch_size=cfg.micro_batch_size,
|
||||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||||
collate_fn=transformers.DataCollatorForSeq2Seq(
|
collate_fn=DataCollatorForSeq2Seq(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding="longest",
|
padding="longest",
|
||||||
@@ -412,7 +421,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
data_collator=DataCollatorForSeq2Seq(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user