diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 839c20c2e..4106a2a7d 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -49,7 +49,8 @@ class GRPOStrategy: if trl.use_vllm: grpo_args_kwargs["use_vllm"] = trl.use_vllm - grpo_args_kwargs["vllm_mode"] = trl.vllm_mode + if trl.vllm_mode: + grpo_args_kwargs["vllm_mode"] = trl.vllm_mode if trl.vllm_mode == "colocate": grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( vllm_cfg.gpu_memory_utilization diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 502c18e7d..aa249c6ce 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1268,6 +1268,19 @@ class DistributedValidationMixin: return self +class GRPOVllmValidationMixin: + """Validation mixin for vllm when using GRPO.""" + + @model_validator(mode="after") + def check_vllm_mode_set(self): + if self.trl and self.trl.use_vllm and not self.trl.vllm_mode: + LOG.warning( + "vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`" + ) + self.trl.vllm_mode = "server" + return self + + # pylint: disable=too-many-ancestors class ValidationMixin( DatasetValidationMixin, @@ -1281,5 +1294,6 @@ class ValidationMixin( PretrainingValidationMixin, ModelCompatibilityValidationMixin, ComplexValidationMixin, + GRPOVllmValidationMixin, ): """Full validation mixin for Axolotl configuration."""