More minor RL fixes (#3551)
* fix: handle get_open_port import across TRL versions. TRL 0.29+ removed get_open_port from its exports; fall back to importing it directly from vllm.utils or vllm.utils.network_utils.
* support DP with vllm and make generation_batch_size configurable
```diff
@@ -89,6 +89,9 @@ class GRPOStrategy:
         if trl.num_generations:
             grpo_args_kwargs["num_generations"] = trl.num_generations
 
+        if trl.generation_batch_size is not None:
+            grpo_args_kwargs["generation_batch_size"] = trl.generation_batch_size
+
         if trl.sync_ref_model:
             grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
```
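For illustration only: the kwargs collected here are ultimately forwarded to TRL's GRPOConfig. A minimal sketch of the equivalent direct construction, assuming a TRL release whose GRPOConfig exposes num_generations, generation_batch_size, and sync_ref_model (the values below are made up):

```python
from trl import GRPOConfig

# Hypothetical values, only to show how the pass-through kwargs line up.
grpo_config = GRPOConfig(
    output_dir="outputs",      # standard TrainingArguments field
    num_generations=4,         # completions sampled per prompt
    generation_batch_size=8,   # size of each generation batch
    sync_ref_model=True,       # periodically refresh the reference policy
)
```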
```diff
@@ -26,8 +26,15 @@ from trl.scripts.vllm_serve import (
     ScriptArguments,
     chunk_list,
     extract_logprobs,
-    get_open_port,
 )
 
+try:
+    from trl.scripts.vllm_serve import get_open_port
+except ImportError:
+    try:
+        from vllm.utils import get_open_port
+    except ImportError:
+        from vllm.utils.network_utils import get_open_port
+
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest
```
```diff
@@ -63,10 +70,21 @@ def llm_worker(
     connection: Connection,
 ) -> None:
     """Worker process that creates a vLLM LLM with LoRA enabled."""
-    os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
-    os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
-    os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
-    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
+    # For DP with TP=1: pin each worker to its own GPU via CUDA_VISIBLE_DEVICES.
+    # vLLM's LLM() offline mode doesn't support DP env vars natively, so we
+    # isolate each worker to a single GPU and let vLLM think it's the only one.
+    if script_args.data_parallel_size > 1 and script_args.tensor_parallel_size == 1:
+        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+        if visible:
+            gpu_ids = visible.split(",")
+            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[data_parallel_rank]
+        else:
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(data_parallel_rank)
+    else:
+        os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
+        os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
+        os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
+        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
 
     llm = LLM(
         model=script_args.model,
```
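The TP=1 branch above leans on standard CUDA behaviour: if CUDA_VISIBLE_DEVICES is restricted before the process initializes CUDA, that process only sees, and re-indexes, the listed devices. A standalone sketch of the effect (illustrative, not part of the worker code):

```python
import os

# Pretend this process is DP rank 3 on a multi-GPU node: expose only GPU 3.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# torch must be imported (CUDA initialized) after the mask is set for it to apply.
import torch

if torch.cuda.is_available():
    # The process now sees exactly one device, addressed as cuda:0, so vLLM's
    # LLM() treats it as the only GPU on the machine.
    print(torch.cuda.device_count())  # -> 1
```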
```diff
@@ -66,6 +66,14 @@ class TRLConfig(BaseModel):
             "description": "List of reward weights for the reward functions."
         },
     )
+    generation_batch_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Batch size for generation. Controls how many unique "
+            "prompts are generated per step. For full DP utilization, set to "
+            "num_generations * data_parallel_size (or a multiple thereof)."
+        },
+    )
     num_generations: int | None = Field(
         default=None,
         json_schema_extra={"description": "Number of generations to sample."},
```
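A worked example of the sizing rule from the new field's description (numbers are illustrative):

```python
# Illustrative numbers only.
data_parallel_size = 2   # vLLM DP replicas serving generation
num_generations = 4      # completions sampled per prompt

# Sizing rule from the field description: "num_generations * data_parallel_size
# (or a multiple thereof)" keeps every DP replica busy each generation step.
generation_batch_size = num_generations * data_parallel_size  # 8

# Each replica then receives an equal share of the generation work.
per_replica_share = generation_batch_size // data_parallel_size  # 4
```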