prevent empty value for vllm_mode (#2998)
This commit is contained in:
@@ -49,7 +49,8 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
if trl.use_vllm:
|
if trl.use_vllm:
|
||||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
if trl.vllm_mode:
|
||||||
|
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||||
if trl.vllm_mode == "colocate":
|
if trl.vllm_mode == "colocate":
|
||||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||||
vllm_cfg.gpu_memory_utilization
|
vllm_cfg.gpu_memory_utilization
|
||||||
|
|||||||
@@ -1268,6 +1268,19 @@ class DistributedValidationMixin:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOVllmValidationMixin:
|
||||||
|
"""Validation mixin for vllm when using GRPO."""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_vllm_mode_set(self):
|
||||||
|
if self.trl and self.trl.use_vllm and not self.trl.vllm_mode:
|
||||||
|
LOG.warning(
|
||||||
|
"vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`"
|
||||||
|
)
|
||||||
|
self.trl.vllm_mode = "server"
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-ancestors
|
# pylint: disable=too-many-ancestors
|
||||||
class ValidationMixin(
|
class ValidationMixin(
|
||||||
DatasetValidationMixin,
|
DatasetValidationMixin,
|
||||||
@@ -1281,5 +1294,6 @@ class ValidationMixin(
|
|||||||
PretrainingValidationMixin,
|
PretrainingValidationMixin,
|
||||||
ModelCompatibilityValidationMixin,
|
ModelCompatibilityValidationMixin,
|
||||||
ComplexValidationMixin,
|
ComplexValidationMixin,
|
||||||
|
GRPOVllmValidationMixin,
|
||||||
):
|
):
|
||||||
"""Full validation mixin for Axolotl configuration."""
|
"""Full validation mixin for Axolotl configuration."""
|
||||||
|
|||||||
Reference in New Issue
Block a user