diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 7eda7a0ba..d1a6b7fd9 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -52,6 +52,7 @@ class GRPOStrategy: if trl.vllm_mode: grpo_args_kwargs["vllm_mode"] = trl.vllm_mode if trl.vllm_mode == "colocate": + grpo_args_kwargs["enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined] grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( vllm_cfg.gpu_memory_utilization ) diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 980474e87..624f7663e 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -167,3 +167,9 @@ class TRLConfig(BaseModel): "description": "Whether to exclude truncated completions from loss calculation." }, ) + vllm_enable_sleep_mode: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable sleep mode for vLLM to offload VRAM when idle" + }, + )