diff --git a/scripts/finetune.py b/scripts/finetune.py index ddf1992d6..70b805ecd 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -231,7 +231,7 @@ def train( cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len, - seed=cfg.seed, + seed=cfg.seed or 42, ) # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 0b6f0a92a..85a32e914 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -91,12 +91,14 @@ class AxolotlTrainer(Trainer): if self.args.sample_packing: train_sampler = self._get_train_sampler() - return MultipackDistributedDataloader( - self.train_dataset, - batch_size=self._train_batch_size, - seq_max_length=self.args.max_seq_length, - collate_fn=self.data_collator, - sampler=train_sampler, + return self.accelerator.prepare( + MultipackDistributedDataloader( + self.train_dataset, + batch_size=self._train_batch_size, + seq_max_length=self.args.max_seq_length, + collate_fn=self.data_collator, + sampler=train_sampler, + ) ) return super().get_train_dataloader() @@ -157,7 +159,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): train_dataset, num_replicas=1, rank=0, - seed=cfg.seed, + seed=cfg.seed or 42, ) data_loader = MultipackDistributedDataloader( train_dataset, @@ -170,12 +172,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): ), sampler=sampler, ) + data_loader_len = len(data_loader) + LOG.info(f"data_loader_len: {data_loader_len}") total_num_steps = int( math.ceil( - len(data_loader) - * cfg.micro_batch_size - * cfg.num_epochs - / cfg.batch_size + data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size ) ) else: @@ -262,8 +263,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg - max_steps=total_num_steps - * cfg.num_epochs, # this is helpful in case we don't actually know total # of steps + max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps per_device_train_batch_size=cfg.micro_batch_size, per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None