add optimization for group-by-len (#563)

This commit is contained in:
Wing Lian
2023-09-13 10:57:12 -04:00
committed by GitHub
parent fdb777bc06
commit e5bb22a56b

View File

@@ -358,7 +358,14 @@ class ReLoRATrainer(AxolotlTrainer):
def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
sample["length"] = sample_len
return sample
def add_length(sample):
sample["length"] = len(sample["input_ids"])
return sample
@@ -382,6 +389,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.group_by_length:
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
if eval_dataset: