prevent empty value for vllm_mode (#2998)

This commit is contained in:
Wing Lian
2025-08-01 09:59:45 -04:00
committed by GitHub
parent 7026cd5e9e
commit 02a37199ee
2 changed files with 16 additions and 1 deletions

View File

@@ -49,7 +49,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization

View File

@@ -1268,6 +1268,19 @@ class DistributedValidationMixin:
return self
class GRPOVllmValidationMixin:
"""Validation mixin for vllm when using GRPO."""
@model_validator(mode="after")
def check_vllm_mode_set(self):
if self.trl and self.trl.use_vllm and not self.trl.vllm_mode:
LOG.warning(
"vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`"
)
self.trl.vllm_mode = "server"
return self
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
@@ -1281,5 +1294,6 @@ class ValidationMixin(
PretrainingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,
):
"""Full validation mixin for Axolotl configuration."""