More minor RL fixes (#3551)
* fix: handle get_open_port import across TRL versions TRL 0.29+ removed get_open_port from exports; fall back to importing directly from vllm.utils or vllm.utils.network_utils. * support DP with vllm and make generation_batch_size confifurable
This commit is contained in:
@@ -89,6 +89,9 @@ class GRPOStrategy:
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
|
||||
if trl.generation_batch_size is not None:
|
||||
grpo_args_kwargs["generation_batch_size"] = trl.generation_batch_size
|
||||
|
||||
if trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||
|
||||
|
||||
@@ -26,8 +26,15 @@ from trl.scripts.vllm_serve import (
|
||||
ScriptArguments,
|
||||
chunk_list,
|
||||
extract_logprobs,
|
||||
get_open_port,
|
||||
)
|
||||
|
||||
try:
|
||||
from trl.scripts.vllm_serve import get_open_port
|
||||
except ImportError:
|
||||
try:
|
||||
from vllm.utils import get_open_port
|
||||
except ImportError:
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
@@ -63,10 +70,21 @@ def llm_worker(
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
"""Worker process that creates a vLLM LLM with LoRA enabled."""
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
# For DP with TP=1: pin each worker to its own GPU via CUDA_VISIBLE_DEVICES.
|
||||
# vLLM's LLM() offline mode doesn't support DP env vars natively, so we
|
||||
# isolate each worker to a single GPU and let vLLM think it's the only one.
|
||||
if script_args.data_parallel_size > 1 and script_args.tensor_parallel_size == 1:
|
||||
visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
|
||||
if visible:
|
||||
gpu_ids = visible.split(",")
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[data_parallel_rank]
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(data_parallel_rank)
|
||||
else:
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
|
||||
@@ -66,6 +66,14 @@ class TRLConfig(BaseModel):
|
||||
"description": "List of reward weights for the reward functions."
|
||||
},
|
||||
)
|
||||
generation_batch_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Batch size for generation. Controls how many unique "
|
||||
"prompts are generated per step. For full DP utilization, set to "
|
||||
"num_generations * data_parallel_size (or a multiple thereof)."
|
||||
},
|
||||
)
|
||||
num_generations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Number of generations to sample."},
|
||||
|
||||
Reference in New Issue
Block a user