diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 5c057cc40..4a8c0b81d 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -89,6 +89,9 @@ class GRPOStrategy: if trl.num_generations: grpo_args_kwargs["num_generations"] = trl.num_generations + if trl.generation_batch_size is not None: + grpo_args_kwargs["generation_batch_size"] = trl.generation_batch_size + if trl.sync_ref_model: grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index 2dda0f9bf..9ca8a9134 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -26,8 +26,15 @@ from trl.scripts.vllm_serve import ( ScriptArguments, chunk_list, extract_logprobs, - get_open_port, ) + +try: + from trl.scripts.vllm_serve import get_open_port +except ImportError: + try: + from vllm.utils import get_open_port + except ImportError: + from vllm.utils.network_utils import get_open_port from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest @@ -63,10 +70,21 @@ def llm_worker( connection: Connection, ) -> None: """Worker process that creates a vLLM LLM with LoRA enabled.""" - os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) - os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) - os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) - os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) + # For DP with TP=1: pin each worker to its own GPU via CUDA_VISIBLE_DEVICES. + # vLLM's LLM() offline mode doesn't support DP env vars natively, so we + # isolate each worker to a single GPU and let vLLM think it's the only one. + if script_args.data_parallel_size > 1 and script_args.tensor_parallel_size == 1: + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if visible: + gpu_ids = visible.split(",") + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[data_parallel_rank] + else: + os.environ["CUDA_VISIBLE_DEVICES"] = str(data_parallel_rank) + else: + os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) + os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) + os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) llm = LLM( model=script_args.model, diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 4ef42db66..cd6a9c57a 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -66,6 +66,14 @@ class TRLConfig(BaseModel): "description": "List of reward weights for the reward functions." }, ) + generation_batch_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Batch size for generation. Controls how many unique " + "prompts are generated per step. For full DP utilization, set to " + "num_generations * data_parallel_size (or a multiple thereof)." + }, + ) num_generations: int | None = Field( default=None, json_schema_extra={"description": "Number of generations to sample."},