diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index a110c08ed..881114634 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -969,7 +969,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.dataset_processes: training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - if self.cfg.trl.beta or self.cfg.rl_beta: + if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta: - training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta + training_args_kwargs["beta"] = (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta if self.cfg.orpo_alpha: # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index a58cfa19c..56415690b 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -33,25 +33,25 @@ class GRPOStrategy: grpo_args_kwargs = {} if cfg.trl and cfg.trl.use_vllm: grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm - if cfg.trl.vllm_device: + if cfg.trl and cfg.trl.vllm_device: grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device else: grpo_args_kwargs["vllm_device"] = "auto" - if cfg.trl.vllm_gpu_memory_utilization: + if cfg.trl and cfg.trl.vllm_gpu_memory_utilization: grpo_args_kwargs[ "vllm_gpu_memory_utilization" ] = cfg.trl.vllm_gpu_memory_utilization - if cfg.trl.vllm_max_model_len: + if cfg.trl and cfg.trl.vllm_max_model_len: grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len if cfg.trl and cfg.trl.num_generations: grpo_args_kwargs["num_generations"] = cfg.trl.num_generations if cfg.trl and cfg.trl.sync_ref_model: grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model - if cfg.trl.ref_model_mixup_alpha: + if cfg.trl and cfg.trl.ref_model_mixup_alpha: grpo_args_kwargs[ "ref_model_mixup_alpha" ] = cfg.trl.ref_model_mixup_alpha - if cfg.trl.ref_model_sync_steps: + if cfg.trl and cfg.trl.ref_model_sync_steps: 
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps - grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length + grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length if cfg.trl else None return grpo_args_kwargs diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index c2c61f478..551e1868e 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -671,7 +671,7 @@ class AxolotlInputConfig( rl: Optional[RLType] = None trl: Optional[TrlConfig] = Field( - default_factory=lambda: TrlConfig() # pylint: disable=unnecessary-lambda + default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda ) reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None