add optimization for group-by-len (#563)
This commit is contained in:
@@ -358,7 +358,14 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
|
||||
|
||||
def add_position_ids(sample):
|
||||
sample_len = len(sample["input_ids"])
|
||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||
sample["length"] = sample_len
|
||||
return sample
|
||||
|
||||
|
||||
def add_length(sample):
|
||||
sample["length"] = len(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
|
||||
@@ -382,6 +389,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||
|
||||
if cfg.group_by_length:
|
||||
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
|
||||
|
||||
if cfg.sample_packing:
|
||||
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
||||
if eval_dataset:
|
||||
|
||||
Reference in New Issue
Block a user