need to pass total num tokens to trainer too
This commit is contained in:
@@ -122,6 +122,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
train_data_total_num_tokens: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the total number of tokens in the train dataset"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -182,6 +186,7 @@ class AxolotlTrainer(Trainer):
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
total_num_tokens=self.args.train_data_total_num_tokens,
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
@@ -204,6 +209,7 @@ class AxolotlTrainer(Trainer):
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
total_num_tokens=None,
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
@@ -468,6 +474,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
|
||||
train_data_total_num_tokens=cfg.total_num_tokens,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user