add optimization for group-by-len (#563)
This commit is contained in:
@@ -358,7 +358,14 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
|
|
||||||
|
|
||||||
def add_position_ids(sample):
|
def add_position_ids(sample):
|
||||||
|
sample_len = len(sample["input_ids"])
|
||||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||||
|
sample["length"] = sample_len
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def add_length(sample):
|
||||||
|
sample["length"] = len(sample["input_ids"])
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
@@ -382,6 +389,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||||
|
|
||||||
|
if cfg.group_by_length:
|
||||||
|
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
|
|||||||
Reference in New Issue
Block a user