diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 5202cb09d..ecfc12309 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -9,6 +9,7 @@ import logging from trl.trainer.grpo_trainer import RewardFunc from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer +from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig LOG = logging.getLogger("axolotl") @@ -31,30 +32,44 @@ class GRPOStrategy: @classmethod def set_training_args_kwargs(cls, cfg): grpo_args_kwargs = {} - if cfg.trl and cfg.trl.use_vllm: - grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm - if cfg.trl and cfg.trl.vllm_device: - grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device - else: - grpo_args_kwargs["vllm_device"] = "auto" - if cfg.trl and cfg.trl.vllm_gpu_memory_utilization: + + if not hasattr(cfg, "trl") or not cfg.trl: + return grpo_args_kwargs + + trl: TRLConfig = cfg.trl # type: ignore + + if trl.use_vllm: + grpo_args_kwargs["use_vllm"] = trl.use_vllm + grpo_args_kwargs["vllm_device"] = ( + trl.vllm_device if trl.vllm_device else "auto" + ) + + if trl.vllm_gpu_memory_utilization: grpo_args_kwargs[ "vllm_gpu_memory_utilization" - ] = cfg.trl.vllm_gpu_memory_utilization - if cfg.trl and cfg.trl.vllm_max_model_len: - grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len - if cfg.trl and cfg.trl.num_generations: - grpo_args_kwargs["num_generations"] = cfg.trl.num_generations - if cfg.trl and cfg.trl.sync_ref_model: - grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model - if cfg.trl and cfg.trl.ref_model_mixup_alpha: - grpo_args_kwargs[ - "ref_model_mixup_alpha" - ] = cfg.trl.ref_model_mixup_alpha - if cfg.trl and cfg.trl.ref_model_sync_steps: - grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps - grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length - grpo_args_kwargs["log_completions"] = cfg.trl.log_completions + ] = trl.vllm_gpu_memory_utilization + + if trl.vllm_max_model_len: + grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len + + if trl.num_generations: + grpo_args_kwargs["num_generations"] = trl.num_generations + + if trl.sync_ref_model: + grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model + + if trl.ref_model_mixup_alpha: + grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha + + if trl.ref_model_sync_steps: + grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps + + grpo_args_kwargs["max_completion_length"] = trl.max_completion_length + grpo_args_kwargs["log_completions"] = trl.log_completions + + if trl.reward_weights: + grpo_args_kwargs["reward_weights"] = trl.reward_weights + return grpo_args_kwargs @classmethod diff --git a/src/axolotl/utils/config/models/input/v0_4_1/trl.py b/src/axolotl/utils/config/models/input/v0_4_1/trl.py index 6361bb249..ae26e634c 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/trl.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/trl.py @@ -27,6 +27,7 @@ class TRLConfig(BaseModel): vllm_dtype: Optional[str] = "auto" reward_funcs: Optional[List[str]] = None + reward_weights: Optional[List[float]] = None num_generations: Optional[int] = None log_completions: Optional[bool] = False