prevent empty value for vllm_mode (#2998)
This commit is contained in:
@@ -49,7 +49,8 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode:
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
|
||||
@@ -1268,6 +1268,19 @@ class DistributedValidationMixin:
|
||||
return self
|
||||
|
||||
|
||||
class GRPOVllmValidationMixin:
|
||||
"""Validation mixin for vllm when using GRPO."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_vllm_mode_set(self):
|
||||
if self.trl and self.trl.use_vllm and not self.trl.vllm_mode:
|
||||
LOG.warning(
|
||||
"vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`"
|
||||
)
|
||||
self.trl.vllm_mode = "server"
|
||||
return self
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class ValidationMixin(
|
||||
DatasetValidationMixin,
|
||||
@@ -1281,5 +1294,6 @@ class ValidationMixin(
|
||||
PretrainingValidationMixin,
|
||||
ModelCompatibilityValidationMixin,
|
||||
ComplexValidationMixin,
|
||||
GRPOVllmValidationMixin,
|
||||
):
|
||||
"""Full validation mixin for Axolotl configuration."""
|
||||
|
||||
Reference in New Issue
Block a user