More minor RL fixes (#3551)

* fix: handle get_open_port import across TRL versions

TRL 0.29+ removed get_open_port from its exports; fall back to importing
it directly from vllm.utils or vllm.utils.network_utils (a diagnostic
sketch for locating the symbol follows this message).

* support data parallelism (DP) with vLLM and make generation_batch_size configurable
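
As noted in the first bullet, the export location varies by version. A
diagnostic sketch (not part of the patch) that reports which of the three
candidate modules provides get_open_port in the installed environment:

# Diagnostic sketch: report where get_open_port lives for the installed
# TRL/vLLM combination. The three candidates mirror the fallback chain.
import importlib

candidates = (
    "trl.scripts.vllm_serve",
    "vllm.utils",
    "vllm.utils.network_utils",
)
for name in candidates:
    try:
        module = importlib.import_module(name)
    except ImportError:
        continue
    if hasattr(module, "get_open_port"):
        print(f"get_open_port exported by {name}")
        break
else:
    print("get_open_port not found in any candidate module")
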
commit 5191e4eb53
parent 74b959e035
Author: Wing Lian
Date: 2026-03-25 18:17:49 -04:00 (committed by GitHub)

3 changed files with 34 additions and 5 deletions


@@ -89,6 +89,9 @@ class GRPOStrategy:
         if trl.num_generations:
             grpo_args_kwargs["num_generations"] = trl.num_generations
+        if trl.generation_batch_size is not None:
+            grpo_args_kwargs["generation_batch_size"] = trl.generation_batch_size
         if trl.sync_ref_model:
             grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model


@@ -26,8 +26,15 @@ from trl.scripts.vllm_serve import (
     ScriptArguments,
     chunk_list,
     extract_logprobs,
-    get_open_port,
 )
+try:
+    from trl.scripts.vllm_serve import get_open_port
+except ImportError:
+    try:
+        from vllm.utils import get_open_port
+    except ImportError:
+        from vllm.utils.network_utils import get_open_port
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest
@@ -63,10 +70,21 @@ def llm_worker(
     connection: Connection,
 ) -> None:
     """Worker process that creates a vLLM LLM with LoRA enabled."""
-    os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
-    os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
-    os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
-    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
+    # For DP with TP=1: pin each worker to its own GPU via CUDA_VISIBLE_DEVICES.
+    # vLLM's LLM() offline mode doesn't support DP env vars natively, so we
+    # isolate each worker to a single GPU and let vLLM think it's the only one.
+    if script_args.data_parallel_size > 1 and script_args.tensor_parallel_size == 1:
+        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+        if visible:
+            gpu_ids = visible.split(",")
+            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[data_parallel_rank]
+        else:
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(data_parallel_rank)
+    else:
+        os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
+        os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
+        os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
+        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
     llm = LLM(
         model=script_args.model,
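
The pinning branch can be exercised on its own. A minimal sketch with
stand-in values for the rank and parallel sizes (no GPUs required; it
only manipulates the environment variable):

import os

# Standalone sketch of the DP/TP=1 pinning branch above; the rank and
# sizes are stand-ins, not taken from a real script_args.
data_parallel_rank = 1
data_parallel_size = 2
tensor_parallel_size = 1

if data_parallel_size > 1 and tensor_parallel_size == 1:
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if visible:
        # Respect an existing mask: with CUDA_VISIBLE_DEVICES="4,5",
        # rank 1 is pinned to physical GPU 5.
        gpu_ids = visible.split(",")
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[data_parallel_rank]
    else:
        # No mask set: rank i is pinned to physical GPU i.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(data_parallel_rank)

print(os.environ["CUDA_VISIBLE_DEVICES"])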


@@ -66,6 +66,14 @@ class TRLConfig(BaseModel):
             "description": "List of reward weights for the reward functions."
         },
     )
+    generation_batch_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Batch size for generation. Controls how many unique "
+            "prompts are generated per step. For full DP utilization, set to "
+            "num_generations * data_parallel_size (or a multiple thereof)."
+        },
+    )
     num_generations: int | None = Field(
         default=None,
         json_schema_extra={"description": "Number of generations to sample."},
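
A hedged usage sketch of the new field, using TRLConfig as defined above
and the sizing rule from its description (values are illustrative;
assumes no other fields are required):

# Illustrative sizing per the field description: enough unique prompts
# per step to keep every vLLM DP replica busy.
num_generations = 8      # completions sampled per prompt
data_parallel_size = 4   # vLLM DP workers
cfg = TRLConfig(
    num_generations=num_generations,
    generation_batch_size=num_generations * data_parallel_size,  # 32
)
print(cfg.generation_batch_size)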