diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 25d327dcd..5ab62343a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -230,6 +230,101 @@ class TrainerBuilderBase(abc.ABC): # TODO return trainer + def _set_base_training_args(self, total_num_steps) -> dict[str, Any]: + training_args_kwargs = {} + + warmup_steps = None + if self.cfg.warmup_steps is not None: + warmup_steps = self.cfg.warmup_steps + elif self.cfg.warmup_ratio is not None: + warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) + else: + warmup_steps = min(int(0.03 * total_num_steps), 100) + if warmup_steps == 1: + warmup_steps = 2 + + logging_steps = ( + self.cfg.logging_steps + if self.cfg.logging_steps is not None + else max(min(int(0.005 * total_num_steps), 10), 1) + ) + + training_args_kwargs["warmup_steps"] = warmup_steps + training_args_kwargs["logging_steps"] = logging_steps + + if self.cfg.hub_model_id: + training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id + training_args_kwargs["push_to_hub"] = True + training_args_kwargs["hub_private_repo"] = True + training_args_kwargs["hub_always_push"] = True + + if self.cfg.hub_strategy: + training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy + + if self.cfg.save_safetensors is not None: + training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors + + # set save_strategy and save_steps + if self.cfg.save_steps: + training_args_kwargs["save_strategy"] = "steps" + training_args_kwargs["save_steps"] = self.cfg.save_steps + elif self.cfg.save_strategy: + training_args_kwargs["save_strategy"] = self.cfg.save_strategy + else: + # default to saving each epoch if not defined + training_args_kwargs["save_strategy"] = "epoch" + + training_args_kwargs["save_only_model"] = self.cfg.save_only_model + + if self.cfg.gradient_checkpointing: + training_args_kwargs[ + "gradient_checkpointing" + ] = self.cfg.gradient_checkpointing + if self.cfg.gradient_checkpointing_kwargs is not None: + training_args_kwargs[ + "gradient_checkpointing_kwargs" + ] = self.cfg.gradient_checkpointing_kwargs + else: + training_args_kwargs["gradient_checkpointing_kwargs"] = { + "use_reentrant": False + } + + for arg in [ + "adam_beta1", + "adam_beta2", + "adam_epsilon", + "max_grad_norm", + "dataloader_num_workers", + "dataloader_pin_memory", + "dataloader_prefetch_factor", + "include_tokens_per_second", + ]: + if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: + training_args_kwargs[arg] = getattr(self.cfg, arg) + + training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size + + if self.cfg.eval_batch_size: + training_args_kwargs[ + "per_device_eval_batch_size" + ] = self.cfg.eval_batch_size + + training_args_kwargs[ + "gradient_accumulation_steps" + ] = self.cfg.gradient_accumulation_steps + + training_args_kwargs["learning_rate"] = self.cfg.learning_rate + training_args_kwargs["output_dir"] = self.cfg.output_dir + training_args_kwargs["save_total_limit"] = ( + self.cfg.save_total_limit if self.cfg.save_total_limit else 4 + ) + + training_args_kwargs["max_steps"] = ( + total_num_steps if self.cfg.max_steps else -1 + ) + + return training_args_kwargs + class HFCausalTrainerBuilder(TrainerBuilderBase): """ @@ -319,29 +414,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlTrainer def build(self, total_num_steps): - warmup_steps = None - if self.cfg.warmup_steps is not None: - warmup_steps = self.cfg.warmup_steps - elif self.cfg.warmup_ratio is not None: - warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) - else: - warmup_steps = min(int(0.03 * total_num_steps), 100) - if warmup_steps == 1: - warmup_steps = 2 - - logging_steps = ( - self.cfg.logging_steps - if self.cfg.logging_steps is not None - else max(min(int(0.005 * total_num_steps), 10), 1) - ) - - training_arguments_kwargs = {} - - if self.cfg.include_tokens_per_second is not None: - training_arguments_kwargs["include_tokens_per_second"] = ( - self.cfg.include_tokens_per_second - ) - + training_arguments_kwargs = self._set_base_training_args(total_num_steps) if self.cfg.bf16 == "full": training_arguments_kwargs["bf16_full_eval"] = True else: @@ -350,20 +423,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.fp16 and not self.cfg.bf16 ) or False training_arguments_kwargs["tf32"] = self.cfg.tf32 - training_arguments_kwargs["warmup_steps"] = warmup_steps - training_arguments_kwargs["logging_steps"] = logging_steps if self.cfg.seed is not None: training_arguments_kwargs["seed"] = self.cfg.seed - if self.cfg.gradient_checkpointing: - training_arguments_kwargs["gradient_checkpointing"] = ( - self.cfg.gradient_checkpointing - ) - if self.cfg.gradient_checkpointing_kwargs is not None: - training_arguments_kwargs["gradient_checkpointing_kwargs"] = ( - self.cfg.gradient_checkpointing_kwargs - ) if self.cfg.fsdp: training_arguments_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp_config: @@ -383,39 +446,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.lr_quadratic_warmup ) - if self.cfg.adam_beta1: - training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 - if self.cfg.adam_beta2: - training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2 - if self.cfg.adam_epsilon: - training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon - if self.cfg.max_grad_norm: - training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm - - if self.cfg.hub_model_id: - training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id - training_arguments_kwargs["push_to_hub"] = True - training_arguments_kwargs["hub_private_repo"] = True - training_arguments_kwargs["hub_always_push"] = True - - if self.cfg.hub_strategy: - training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy - - if self.cfg.save_safetensors is not None: - training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors - - if self.cfg.dataloader_pin_memory is not None: - training_arguments_kwargs["dataloader_pin_memory"] = ( - self.cfg.dataloader_pin_memory - ) - if self.cfg.dataloader_num_workers is not None: - training_arguments_kwargs["dataloader_num_workers"] = ( - self.cfg.dataloader_num_workers - ) - if self.cfg.dataloader_prefetch_factor is not None: - training_arguments_kwargs["dataloader_prefetch_factor"] = ( - self.cfg.dataloader_prefetch_factor - ) if self.cfg.dataloader_drop_last is not None: training_arguments_kwargs["dataloader_drop_last"] = ( self.cfg.dataloader_drop_last @@ -440,17 +470,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): # we have an eval set, but no steps defined, default to use epoch training_arguments_kwargs["eval_strategy"] = "epoch" - if self.cfg.save_steps: - training_arguments_kwargs["save_strategy"] = "steps" - training_arguments_kwargs["save_steps"] = self.cfg.save_steps - elif self.cfg.save_strategy: - training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy - else: - # default to saving each epoch if not defined - training_arguments_kwargs["save_strategy"] = "epoch" - - training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model - if self.cfg.do_bench_eval: training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval if self.cfg.bench_dataset: @@ -493,33 +512,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) # these are all the "standard" kwargs that are def used - training_arguments_kwargs["max_steps"] = ( - self.cfg.max_steps if self.cfg.max_steps else -1 - ) training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len - training_arguments_kwargs["per_device_train_batch_size"] = ( - self.cfg.micro_batch_size - ) - if self.cfg.eval_batch_size: - training_arguments_kwargs["per_device_eval_batch_size"] = ( - self.cfg.eval_batch_size - ) + if self.cfg.auto_find_batch_size is not None: - training_arguments_kwargs["auto_find_batch_size"] = ( - self.cfg.auto_find_batch_size - ) - training_arguments_kwargs["gradient_accumulation_steps"] = ( - self.cfg.gradient_accumulation_steps - ) - training_arguments_kwargs["eval_accumulation_steps"] = ( - self.cfg.gradient_accumulation_steps - ) + training_arguments_kwargs[ + "auto_find_batch_size" + ] = self.cfg.auto_find_batch_size + + training_arguments_kwargs[ + "eval_accumulation_steps" + ] = self.cfg.gradient_accumulation_steps training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs - training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate - training_arguments_kwargs["output_dir"] = self.cfg.output_dir - training_arguments_kwargs["save_total_limit"] = ( - self.cfg.save_total_limit if self.cfg.save_total_limit else 4 - ) + training_arguments_kwargs["load_best_model_at_end"] = ( ( self.cfg.load_best_model_at_end is not False @@ -974,34 +978,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase): return callbacks def build_training_arguments(self, total_num_steps): - training_args_kwargs = {} - for arg in [ - "adam_beta1", - "adam_beta2", - "adam_epsilon", - "dataloader_num_workers", - "dataloader_pin_memory", - ]: - if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: - training_args_kwargs[arg] = getattr(self.cfg, arg) + training_args_kwargs = self._set_base_training_args( + total_num_steps=total_num_steps + ) - if self.cfg.hub_model_id: - training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id - training_args_kwargs["push_to_hub"] = True - training_args_kwargs["hub_private_repo"] = True - training_args_kwargs["hub_always_push"] = True - - if self.cfg.hub_strategy: - training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy - - if self.cfg.save_safetensors is not None: - training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors - - if self.eval_dataset: + if not self.eval_dataset: + training_args_kwargs["eval_strategy"] = "no" + elif self.cfg.eval_steps: training_args_kwargs["eval_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps - else: - training_args_kwargs["eval_strategy"] = "no" + elif self.cfg.eval_strategy: + training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True @@ -1014,6 +1001,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["lr_scheduler_kwargs"] = ( self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} ) + if self.cfg.remove_unused_columns is not None: training_args_kwargs["remove_unused_columns"] = ( self.cfg.remove_unused_columns @@ -1021,47 +1009,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: training_args_kwargs["remove_unused_columns"] = False - if self.cfg.dataloader_pin_memory is not None: - training_args_kwargs["dataloader_pin_memory"] = ( - self.cfg.dataloader_pin_memory - ) - if self.cfg.dataloader_num_workers is not None: - training_args_kwargs["dataloader_num_workers"] = ( - self.cfg.dataloader_num_workers - ) - if self.cfg.dataloader_prefetch_factor is not None: - training_args_kwargs["dataloader_prefetch_factor"] = ( - self.cfg.dataloader_prefetch_factor - ) - - if self.cfg.seed is not None: - training_args_kwargs["seed"] = self.cfg.seed - - if self.cfg.gradient_checkpointing: - training_args_kwargs["gradient_checkpointing"] = ( - self.cfg.gradient_checkpointing - ) - if self.cfg.gradient_checkpointing_kwargs is not None: - training_args_kwargs["gradient_checkpointing_kwargs"] = ( - self.cfg.gradient_checkpointing_kwargs - ) - else: - training_args_kwargs["gradient_checkpointing_kwargs"] = { - "use_reentrant": False - } - - # set save_strategy and save_steps - if self.cfg.save_steps: - training_args_kwargs["save_strategy"] = "steps" - training_args_kwargs["save_steps"] = self.cfg.save_steps - elif self.cfg.save_strategy: - training_args_kwargs["save_strategy"] = self.cfg.save_strategy - else: - # default to saving each epoch if not defined - training_args_kwargs["save_strategy"] = "epoch" - - training_args_kwargs["save_only_model"] = self.cfg.save_only_model - if self.cfg.dataset_processes: training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes @@ -1137,19 +1084,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] - max_steps = self.cfg.max_steps or total_num_steps or -1 training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg - self.cfg.output_dir, - per_device_train_batch_size=self.cfg.micro_batch_size, - max_steps=max_steps, - gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, - learning_rate=self.cfg.learning_rate, - warmup_steps=self.cfg.warmup_steps, logging_first_step=True, - logging_steps=1, optim=self.cfg.optimizer, - save_total_limit=self.cfg.save_total_limit or 5, **training_args_kwargs, )