From 79500f358a871eb75d32542b8b83614de8226921 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 10 Aug 2023 19:08:23 -0400 Subject: [PATCH] need to pass total num tokens to trainer too --- src/axolotl/utils/trainer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 82380183c..3eed07b4c 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -122,6 +122,10 @@ class AxolotlTrainingArguments(TrainingArguments): default=1, metadata={"help": "the multiplier for the max len for packed sequences"}, ) + train_data_total_num_tokens: Optional[int] = field( + default=None, + metadata={"help": "the total number of tokens in the train dataset"}, + ) class AxolotlTrainer(Trainer): @@ -182,6 +186,7 @@ class AxolotlTrainer(Trainer): packing_efficiency_estimate=self.args.sample_packing_efficiency, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, device_count=int(os.environ.get("WORLD_SIZE", 1)), + total_num_tokens=self.args.train_data_total_num_tokens, ) ) return super().get_train_dataloader() @@ -204,6 +209,7 @@ class AxolotlTrainer(Trainer): packing_efficiency_estimate=self.args.sample_packing_efficiency, sample_packing_seq_len_multiplier=self.args.eval_batch_size, device_count=int(os.environ.get("WORLD_SIZE", 1)), + total_num_tokens=None, ) ) return super().get_eval_dataloader(eval_dataset) @@ -468,6 +474,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, sample_packing=cfg.sample_packing if cfg.sample_packing else False, sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1, + train_data_total_num_tokens=cfg.total_num_tokens, **training_arguments_kwargs, )