diff --git a/docs/config.qmd b/docs/config.qmd
index eba9f4881..f03c9da06 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -529,7 +529,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
 
-# Save model as safetensors (require safetensors package)
+# Save model as safetensors (requires the safetensors package). Transformers defaults this to True
 save_safetensors:
 
 # Whether to mask out or include the human's prompt from the training labels
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 6fb36f299..c9cc2070e 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -252,14 +252,15 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["warmup_steps"] = warmup_steps
         training_args_kwargs["logging_steps"] = logging_steps
 
+        # precision
         training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
         training_args_kwargs["tf32"] = self.cfg.tf32
-
         if self.cfg.bf16 == "full":
             training_args_kwargs["bf16_full_eval"] = True
         else:
             training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
 
+        # hub
         if self.cfg.hub_model_id:
             training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
             training_args_kwargs["push_to_hub"] = True
@@ -269,10 +270,7 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.hub_strategy:
             training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
 
-        if self.cfg.save_safetensors is not None:
-            training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
-
-        # set save_strategy and save_steps
+        # save_strategy and save_steps
         if self.cfg.save_steps:
             training_args_kwargs["save_strategy"] = "steps"
             training_args_kwargs["save_steps"] = self.cfg.save_steps
@@ -282,7 +280,15 @@
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"
 
-        training_args_kwargs["save_only_model"] = self.cfg.save_only_model
+        # eval_strategy and eval_steps
+        if not self.eval_dataset or self.cfg.val_set_size == 0:
+            # do not eval when there is no eval_dataset or val_set_size is 0
+            training_args_kwargs["eval_strategy"] = "no"
+        elif self.cfg.eval_steps:
+            training_args_kwargs["eval_strategy"] = "steps"
+            training_args_kwargs["eval_steps"] = self.cfg.eval_steps
+        elif self.cfg.eval_strategy:
+            training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
 
         if self.cfg.gradient_checkpointing:
             training_args_kwargs[
@@ -297,6 +303,7 @@
                 "use_reentrant": False
             }
 
+        # copy each cfg value into training_args_kwargs under the same name when not None
         for arg in [
             "adam_beta1",
             "adam_beta2",
@@ -305,6 +312,11 @@
             "dataloader_num_workers",
             "dataloader_pin_memory",
             "dataloader_prefetch_factor",
+            "gradient_accumulation_steps",
+            "learning_rate",
+            "output_dir",
+            "save_safetensors",
+            "save_only_model",
             "include_tokens_per_second",
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
@@ -317,12 +329,6 @@
                 "per_device_eval_batch_size"
             ] = self.cfg.eval_batch_size
 
-        training_args_kwargs[
-            "gradient_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
-
-        training_args_kwargs["learning_rate"] = self.cfg.learning_rate
-        training_args_kwargs["output_dir"] = self.cfg.output_dir
training_args_kwargs["save_total_limit"] = ( self.cfg.save_total_limit if self.cfg.save_total_limit else 4 ) @@ -331,6 +337,11 @@ class TrainerBuilderBase(abc.ABC): total_num_steps if self.cfg.max_steps else -1 ) + # max_length is not used in CausalTrainer + if self.cfg.reward_model or self.cfg.rl: + training_args_kwargs["max_length"] = self.cfg.sequence_len + + # reporting report_to = [] if self.cfg.use_wandb: report_to.append("wandb") @@ -349,6 +360,24 @@ class TrainerBuilderBase(abc.ABC): else: training_args_kwargs["run_name"] = None + # optim/scheduler + training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio + training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding + if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: + training_args_kwargs["lr_scheduler_type"] = "cosine" + training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler + else: + training_args_kwargs["lr_scheduler_type"] = ( + self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" + ) + training_args_kwargs["lr_scheduler_kwargs"] = ( + self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} + ) + training_args_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio + training_args_kwargs[ + "cosine_constant_lr_ratio" + ] = self.cfg.cosine_constant_lr_ratio + return training_args_kwargs @@ -476,18 +505,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.remove_unused_columns ) - if not self.cfg.test_datasets and self.cfg.val_set_size == 0: - # no eval set, so don't eval - training_arguments_kwargs["eval_strategy"] = "no" - elif self.cfg.eval_steps: - training_arguments_kwargs["eval_strategy"] = "steps" - training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps - elif self.cfg.eval_strategy: - training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy - else: - # we have an eval set, but no steps defined, default to use epoch - training_arguments_kwargs["eval_strategy"] = "epoch" - if self.cfg.do_bench_eval: training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval if self.cfg.bench_dataset: @@ -582,30 +599,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs[ "optim_target_modules" ] = self.cfg.optim_target_modules - training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio - training_arguments_kwargs[ - "loraplus_lr_embedding" - ] = self.cfg.loraplus_lr_embedding training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups - if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: - training_arguments_kwargs["lr_scheduler_type"] = "cosine" - training_arguments_kwargs["alternate_lr_scheduler_type"] = ( - self.cfg.lr_scheduler - ) - else: - training_arguments_kwargs["lr_scheduler_type"] = ( - self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" - ) - training_arguments_kwargs["lr_scheduler_kwargs"] = ( - self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} - ) - training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio - training_arguments_kwargs["cosine_constant_lr_ratio"] = ( - self.cfg.cosine_constant_lr_ratio - ) training_arguments_kwargs["weight_decay"] = ( self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 ) @@ -671,9 +668,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): trainer_kwargs = {} - if self.cfg.reward_model: - 
training_arguments_kwargs["max_length"] = self.cfg.sequence_len - # Handle custom optimizer custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers] if self.cfg.optimizer in custom_supported_optimizers: @@ -1006,22 +1000,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): total_num_steps=total_num_steps ) - if not self.eval_dataset: - training_args_kwargs["eval_strategy"] = "no" - elif self.cfg.eval_steps: - training_args_kwargs["eval_strategy"] = "steps" - training_args_kwargs["eval_steps"] = self.cfg.eval_steps - elif self.cfg.eval_strategy: - training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy - - training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio - training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding training_args_kwargs["lr_scheduler_type"] = ( self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" ) - training_args_kwargs["lr_scheduler_kwargs"] = ( - self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} - ) if self.cfg.remove_unused_columns is not None: training_args_kwargs["remove_unused_columns"] = ( @@ -1056,14 +1037,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl is RLType.SIMPO: training_args_cls = AxolotlCPOConfig training_args_kwargs["loss_type"] = "simpo" - training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma if self.cfg.cpo_alpha is not None: training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha elif self.cfg.rl is RLType.ORPO: training_args_cls = AxolotlORPOConfig - training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len @@ -1077,7 +1056,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): self.cfg.kto_undesirable_weight or 1.0 ) - training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len @@ -1090,7 +1068,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = AxolotlDPOConfig if self.cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = "ipo" - training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_completion_length"] = None training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb