pass sample packing efficiency to training args
This commit is contained in:
@@ -47,6 +47,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
)
|
)
|
||||||
|
sample_packing_efficiency: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
|
)
|
||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=2048,
|
default=2048,
|
||||||
metadata={"help": "The maximum sequence length the model can handle"},
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
@@ -109,6 +113,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
seq_max_length=self.args.max_seq_length,
|
seq_max_length=self.args.max_seq_length,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -310,6 +315,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.save_safetensors:
|
if cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||||
|
|
||||||
|
if cfg.sample_packing_eff_est:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_efficiency"
|
||||||
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user