Pass sample packing efficiency to training args

This commit is contained in:
Wing Lian
2023-07-26 00:06:28 -04:00
parent 2bee646e85
commit 487abfc769

View File

@@ -47,6 +47,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
@@ -109,6 +113,7 @@ class AxolotlTrainer(Trainer):
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
)
return super().get_train_dataloader()
@@ -310,6 +315,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
if cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_seq_length=cfg.sequence_len,