diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index 4934bf9ca..dfa315618 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -42,8 +42,7 @@ jobs:
             cuda_version: 12.4.1
             python_version: "3.11"
             pytorch: 2.6.0
-            # awaiting vllm#12721
-            axolotl_extras:
+            axolotl_extras: vllm
             num_gpus: 2
             nightly_build: "true"
     runs-on: [self-hosted, modal]
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index ad6305e8f..a1a7214ec 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -256,7 +256,7 @@ jobs:
             python_version: "3.11"
             pytorch: 2.6.0
             num_gpus: 1
-            axolotl_extras:
+            axolotl_extras: vllm
    steps:
      - name: Checkout
        uses: actions/checkout@v4
diff --git a/_quarto.yml b/_quarto.yml
index 804fc5e84..cf85da473 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -40,6 +40,7 @@ quartodoc:
        - cli.preprocess
        - cli.sweeps
        - cli.utils
+        - cli.vllm_serve
        - cli.cloud.base
        - cli.cloud.modal_
    - title: Trainers
diff --git a/cicd/multigpu.sh b/cicd/multigpu.sh
index 05d1bbbf2..84dfc6f71 100755
--- a/cicd/multigpu.sh
+++ b/cicd/multigpu.sh
@@ -2,4 +2,5 @@
 set -e
 
 # only run one test at a time so as not to OOM the GPU
-pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
+pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
+pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
diff --git a/docs/config.qmd b/docs/config.qmd
index 753cf47e1..b0c8616a2 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
 # grpo
 trl:
   use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
-  vllm_device: # Optional[str]. Device to use for VLLM.
-  vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
-  vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
-  vllm_dtype: # Optional[str]. Data type for VLLM.
+  vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
+  vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
+  vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
+  vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
   beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`.
   max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 6bef7c831..b3adb5937 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the above
 Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
 :::
 
+If you have multiple GPUs available, we recommend using `vLLM` with the `GRPOTrainer` to significantly speed up trajectory generation during training.
+First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
+using 4 GPUs - 2 for training, and 2 for vLLM:
+
+::: {.callout-important}
+Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
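+To double-check which build was installed:
+
+```bash
+python -c "import vllm; print(vllm.__version__)"
+```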
+:::
+
+```yaml
+base_model: Qwen/Qwen2.5-1.5B-Instruct
+
+vllm:
+  host: 0.0.0.0
+  port: 8000
+  tensor_parallel_size: 2
+  gpu_memory_utilization: 0.85
+  dtype: auto
+  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
+
+rl: grpo
+trl:
+  use_vllm: true
+  vllm_server_host: 0.0.0.0
+  vllm_server_port: 8000
+  vllm_server_timeout: 300
+```
+
+```bash
+CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
+```
+
+Your `vLLM` server will now spin up. Once it is ready, kick off training using the remaining two GPUs. In another terminal, execute:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
+```
+
+#### Reward functions
+
 GRPO uses custom reward functions and transformations. Please have them ready locally.
 
-For ex, to load OpenAI's GSM8K and use a random reward for completions:
+For example, to load OpenAI's GSM8K and use a random reward for completions:
 
 ```python
 # rewards.py
@@ -530,8 +569,6 @@ trl:
   beta: 0.001
   max_completion_length: 256
   use_vllm: True
-  vllm_device: auto
-  vllm_gpu_memory_utilization: 0.15
   num_generations: 4
   reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
   reward_weights: [1.0]
diff --git a/requirements.txt b/requirements.txt
index 096237f19..9aff0ccfe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.45.3
+bitsandbytes==0.45.4
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
@@ -17,7 +17,7 @@
 tokenizers>=0.21.1
 accelerate==1.5.2
 datasets==3.5.0
 deepspeed==0.16.4
-trl==0.15.1
+trl==0.16.0
 optimum==1.16.2
 hf_transfer
diff --git a/setup.py b/setup.py
index 4c3024b85..d19c14828 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ from pathlib import Path
 from setuptools import find_packages, setup
 
 
-def parse_requirements():
+def parse_requirements(extras_require_map):
     _install_requires = []
     _dependency_links = []
     with open("./requirements.txt", encoding="utf-8") as requirements_file:
@@ -67,6 +67,7 @@ def parse_requirements():
         if (major, minor) >= (2, 6):
             _install_requires.pop(_install_requires.index(xformers_version))
             _install_requires.append("xformers==0.0.29.post2")
+            extras_require_map["vllm"] = ["vllm==0.8.1"]
         elif (major, minor) >= (2, 5):
             _install_requires.pop(_install_requires.index(xformers_version))
             if patch == 0:
@@ -86,7 +87,7 @@ def parse_requirements():
     except PackageNotFoundError:
         pass
 
-    return _install_requires, _dependency_links
+    return _install_requires, _dependency_links, extras_require_map
 
 
 def get_package_version():
@@ -103,7 +104,46 @@ def get_package_version():
     return version_
 
 
-install_requires, dependency_links = parse_requirements()
+extras_require = {
+    "flash-attn": ["flash-attn==2.7.4.post1"],
+    "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
+    "deepspeed": [
+        "deepspeed==0.16.4",
+        "deepspeed-kernels",
+    ],
+    "mamba-ssm": [
+        "mamba-ssm==1.2.0.post1",
+        "causal_conv1d",
+    ],
+    "auto-gptq": [
+        "auto-gptq==0.5.1",
+    ],
+    "mlflow": [
+        "mlflow",
+    ],
+    "galore": [
+        "galore_torch",
+    ],
+    "apollo": [
+        "apollo-torch",
+    ],
+    "optimizers": [
+        "galore_torch",
+        "apollo-torch",
+        "lomo-optim==0.1.1",
+        "torch-optimi==0.2.1",
+    ],
+    "ray": [
+        "ray[train]",
+    ],
+    "vllm": [
+        "vllm==0.7.2",
+    ],
+}
+
+install_requires, dependency_links, extras_require_build = parse_requirements(
+    extras_require
+)
 
 setup(
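+    # NOTE: `extras_require_build` (used below) is `extras_require` after any
+    # torch-version-specific overrides applied in parse_requirements
+    # (e.g. the vllm extra resolves to vllm==0.8.1 on torch>=2.6)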
    version=get_package_version(),
@@ -116,40 +156,5 @@ setup(
             "axolotl=axolotl.cli.main:main",
         ],
     },
-    extras_require={
-        "flash-attn": ["flash-attn==2.7.4.post1"],
-        "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
-        "deepspeed": [
-            "deepspeed==0.16.4",
-            "deepspeed-kernels",
-        ],
-        "mamba-ssm": [
-            "mamba-ssm==1.2.0.post1",
-            "causal_conv1d",
-        ],
-        "auto-gptq": [
-            "auto-gptq==0.5.1",
-        ],
-        "mlflow": [
-            "mlflow",
-        ],
-        "galore": [
-            "galore_torch",
-        ],
-        "apollo": [
-            "apollo-torch",
-        ],
-        "optimizers": [
-            "galore_torch",
-            "apollo-torch",
-            "lomo-optim==0.1.1",
-            "torch-optimi==0.2.1",
-        ],
-        "ray": [
-            "ray[train]",
-        ],
-        "vllm": [
-            "vllm==0.7.2",
-        ],
-    },
+    extras_require=extras_require_build,
 )
diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py
index a39ffc308..72e61d1bb 100644
--- a/src/axolotl/cli/args.py
+++ b/src/axolotl/cli/args.py
@@ -35,6 +35,55 @@ class TrainerCliArgs:
     num_processes: Optional[int] = field(default=None)
 
 
+@dataclass
+class VllmServeCliArgs:
+    """Dataclass with CLI arguments for `axolotl vllm-serve` command."""
+
+    tensor_parallel_size: int = field(
+        default=1,
+        metadata={"help": "Number of tensor parallel workers to use."},
+    )
+    host: str = field(
+        default="0.0.0.0",  # nosec B104
+        metadata={"help": "Host address to run the server on."},
+    )
+    port: int = field(
+        default=8000,
+        metadata={"help": "Port to run the server on."},
+    )
+    gpu_memory_utilization: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
+            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
+            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
+            "out-of-memory (OOM) errors during initialization."
+        },
+    )
+    dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
+            "determined based on the model configuration. Find the supported values in the vLLM documentation."
+        },
+    )
+    max_model_len: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
+            "`gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
+            "context size, which might be much larger than the KV cache, leading to inefficiencies."
+        },
+    )
+    enable_prefix_caching: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
+            "hardware support this feature."
+        },
+    )
+
+
 @dataclass
 class EvaluateCliArgs:
     """Dataclass with CLI arguments for `axolotl evaluate` command."""
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index f53ea825a..7532a9689 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -14,7 +14,12 @@ import yaml
 from dotenv import load_dotenv
 
 import axolotl
-from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.cli.args import (
+    EvaluateCliArgs,
+    PreprocessCliArgs,
+    TrainerCliArgs,
+    VllmServeCliArgs,
+)
 from axolotl.cli.sweeps import generate_sweep_configs
 from axolotl.cli.utils import (
     add_options_from_config,
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
     fetch_from_github,
     filter_none_kwargs,
 )
+from axolotl.cli.vllm_serve import do_vllm_serve
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
     fetch_from_github(f"{directory}/", dest)
 
 
+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@add_options_from_dataclass(VllmServeCliArgs)
+@filter_none_kwargs
+def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
+    do_vllm_serve(config, cli_args)
+
+
 cli.add_command(lm_eval)
diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py
new file mode 100644
index 000000000..552f33e9e
--- /dev/null
+++ b/src/axolotl/cli/vllm_serve.py
@@ -0,0 +1,55 @@
+"""
+CLI to start the vLLM server for online RL
+"""
+
+from pathlib import Path
+from typing import Union
+
+from trl.scripts.vllm_serve import ScriptArguments
+from trl.scripts.vllm_serve import main as vllm_serve_main
+
+from axolotl.cli.config import load_cfg
+
+
+def do_vllm_serve(
+    config: Union[Path, str],
+    cli_args: dict,
+):
+    """
+    Starts the vLLM server for serving LLM models used for online RL.
+
+    Args:
+        :param config: Path to the YAML config file.
+        :param cli_args: dict of additional command-line arguments of type VllmServeCliArgs.
+    """
+    cfg = load_cfg(config)
+    model = cfg.base_model
+
+    tensor_parallel_size = (
+        cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
+    )
+    host = cli_args.get("host") or cfg.vllm.host
+    port = cli_args.get("port") or cfg.vllm.port
+    gpu_memory_utilization = (
+        cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
+    )
+    dtype = cli_args.get("dtype") or cfg.vllm.dtype
+    max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
+    enable_prefix_caching = (
+        cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
+    )
+
+    vllm_script_args = ScriptArguments(
+        model,
+        tensor_parallel_size=tensor_parallel_size,
+        host=host,
+        port=port,
+        gpu_memory_utilization=gpu_memory_utilization,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        enable_prefix_caching=enable_prefix_caching,
+    )
+    vllm_serve_main(vllm_script_args)
diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index f0c42830d..219eced69 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -40,18 +40,15 @@ class GRPOStrategy:
 
         if trl.use_vllm:
             grpo_args_kwargs["use_vllm"] = trl.use_vllm
-            grpo_args_kwargs["vllm_device"] = (
-                trl.vllm_device if trl.vllm_device else "auto"
-            )
-
-            if trl.vllm_gpu_memory_utilization:
-                grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
-                    trl.vllm_gpu_memory_utilization
+            grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
+            grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
+            if trl.vllm_server_timeout:
+                grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
+            if trl.vllm_guided_decoding_regex:
+                grpo_args_kwargs["vllm_guided_decoding_regex"] = (
+                    trl.vllm_guided_decoding_regex
                 )
-            if trl.vllm_max_model_len:
-                grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
-
         if trl.num_generations:
             grpo_args_kwargs["num_generations"] = trl.num_generations
 
@@ -70,6 +67,25 @@ class GRPOStrategy:
         if trl.reward_weights:
             grpo_args_kwargs["reward_weights"] = trl.reward_weights
 
+        if trl.scale_rewards is not None:
+            grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
+
+        if trl.temperature is not None:
+            grpo_args_kwargs["temperature"] = trl.temperature
+        if trl.top_p is not None:
+            grpo_args_kwargs["top_p"] = trl.top_p
+        if trl.top_k is not None:
+            grpo_args_kwargs["top_k"] = trl.top_k
+        if trl.min_p is not None:
+            grpo_args_kwargs["min_p"] = trl.min_p
+        if trl.repetition_penalty is not None:
+            grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
+
+        if trl.num_iterations is not None:
+            grpo_args_kwargs["num_iterations"] = trl.num_iterations
+        if trl.epsilon is not None:
+            grpo_args_kwargs["epsilon"] = trl.epsilon
+
         return grpo_args_kwargs
 
     @classmethod
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index 663bed094..e8a142945 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -2,16 +2,18 @@
 Axolotl GRPO trainer
 """
 
-from accelerate.utils import is_peft_model
-from accelerate.utils.other import is_compiled_module
-from transformers import PreTrainedModel
-from trl import GRPOConfig, GRPOTrainer
-from trl.models import unwrap_model_for_generation
+from contextlib import nullcontext
+
+from accelerate.utils import is_deepspeed_available, is_peft_model
+from trl import GRPOTrainer
+from trl.extras.profiling import profiling_decorator
 
 from axolotl.core.trainers.base import SchedulerMixin
 
+if is_deepspeed_available():
+    import deepspeed
+
 
-# mypy: ignore-errors
 class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
     """
     Extend the base GRPOTrainer for axolotl helpers
@@ -19,91 +21,49 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
 
     _tag_names = ["trl", "grpo", "axolotl"]
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # pylint: disable=access-member-before-definition
-        # Enable gradient checkpointing if requested
-        if kwargs["args"].gradient_checkpointing:
-            # Ensure use_cache is disabled
-            if hasattr(self.model, "config"):
-                self.model.config.use_cache = False
-
-            # Enable gradient checkpointing on the base model for PEFT
-            if is_peft_model(self.model) and hasattr(
-                self.model.base_model, "gradient_checkpointing_enable"
-            ):
-                self.model.base_model.gradient_checkpointing_enable()
-            # Enable gradient checkpointing for non-PEFT models
-            elif hasattr(self.model, "gradient_checkpointing_enable"):
-                self.model.gradient_checkpointing_enable()
-            self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
-        # pylint: enable=access-member-before-definition
-
-    def _enable_gradient_checkpointing(
-        self, model: PreTrainedModel, args: GRPOConfig
-    ) -> PreTrainedModel:
-        """Enables gradient checkpointing for the model."""
-        # pylint: disable=unused-argument,redefined-builtin
-
-        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
-        use_reentrant = (
-            "use_reentrant" not in gradient_checkpointing_kwargs
-            or gradient_checkpointing_kwargs["use_reentrant"]
+    @profiling_decorator
+    def _move_model_to_vllm(self):
+        # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
+        gather_if_zero3 = (
+            deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
         )
 
-        if use_reentrant:
-            if hasattr(model, "enable_input_require_grads"):
-                model.enable_input_require_grads()
-            else:
+        if is_peft_model(self.model):
+            # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
+            # adapters in a sharded manner is not supported.
+            with gather_if_zero3(list(self.model.parameters())):
+                self.model.merge_adapter()
 
-                def make_inputs_require_grad(module, input, output):
-                    output.requires_grad_(True)
+                # Update vLLM weights while parameters are gathered
+                for name, param in self.model.named_parameters():
+                    # When using PEFT, we need to recover the original parameter name and discard some parameters
+                    name = (
+                        name.removeprefix("base_model.model.")
+                        .removeprefix("base_model.model.")
+                        .replace(".base_layer", "")
+                    )
+                    if self.model.prefix in name:
+                        continue
+                    # When module to save, remove its prefix and discard the original module
+                    if "original_module" in name:
+                        continue
+                    name = name.replace("modules_to_save.default.", "")
 
-                model.get_input_embeddings().register_forward_hook(
-                    make_inputs_require_grad
-                )
+                    if self.accelerator.is_main_process:
+                        self.vllm_client.update_named_param(name, param.data)
 
-        return model
-        # pylint: enable=unused-argument,redefined-builtin
+                # Unmerge adapters while parameters are still gathered
+                self.model.unmerge_adapter()
+                # Parameters will automatically be repartitioned when exiting the context
+        else:
+            # For non-PEFT models, simply gather and update each parameter individually.
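+            # (gathering one parameter at a time keeps peak per-rank memory low under ZeRO-3)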
+            for name, param in self.model.named_parameters():
+                with gather_if_zero3([param]):
+                    if self.accelerator.is_main_process:
+                        self.vllm_client.update_named_param(name, param.data)
 
-    def _move_model_to_vllm(self):
-        with unwrap_model_for_generation(
-            self.model,
-            self.accelerator,
-            gather_deepspeed3_params=self.args.ds3_gather_for_generation,
-        ) as unwrapped_model:
-            if is_compiled_module(unwrapped_model):
-                unwrapped_model = (
-                    unwrapped_model._orig_mod  # pylint: disable=protected-access
-                )
-            if is_peft_model(unwrapped_model):
-                unwrapped_model.merge_adapter()
-                state_dict = unwrapped_model.state_dict()
-                # Remove base_model and base_layer prefixes
-                state_dict = {
-                    k.removeprefix("base_model.model.")
-                    .removeprefix("base_model.model.")
-                    .replace(".base_layer", ""): v
-                    for k, v in state_dict.items()
-                }
-                # Remove values with adapter prefix (example: "_lora")
-                state_dict = {
-                    k: v
-                    for k, v in state_dict.items()
-                    if unwrapped_model.prefix not in k
-                }
-                # When module to save, remove its prefix and discard the original module
-                state_dict = {
-                    k.replace("modules_to_save.default.", ""): v
-                    for k, v in state_dict.items()
-                    if "original_module" not in k
-                }
-            else:
-                state_dict = unwrapped_model.state_dict()
-            if self.accelerator.is_main_process:
-                llm_model = (
-                    self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-                )
-                llm_model.load_weights(state_dict.items())
-            if is_peft_model(unwrapped_model):
-                unwrapped_model.unmerge_adapter()
+        # Reset cache on main process
+        if self.accelerator.is_main_process:
+            self.vllm_client.reset_prefix_cache()
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 51c5cf08e..c7be33ab3 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -46,6 +46,7 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
 from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
 from axolotl.utils.schemas.training import HyperparametersConfig
 from axolotl.utils.schemas.trl import TRLConfig
+from axolotl.utils.schemas.vllm import VllmConfig
 
 LOG = logging.getLogger(__name__)
 
@@ -86,6 +87,9 @@ class AxolotlInputConfig(
     trl: TRLConfig | None = Field(
         default_factory=lambda: TRLConfig(),  # pylint: disable=unnecessary-lambda
     )
+    vllm: VllmConfig | None = Field(
+        default_factory=lambda: VllmConfig(),  # pylint: disable=unnecessary-lambda
+    )
     reward_model: bool | None = None
     process_reward_model: bool | None = None
     num_labels: int | None = None
diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py
index 60759769d..a051fb0ab 100644
--- a/src/axolotl/utils/schemas/trl.py
+++ b/src/axolotl/utils/schemas/trl.py
@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
     )
 
     # GRPO specific args
-    # Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
-    use_vllm: bool | None = Field(
+    # Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
+    use_vllm: bool = Field(
         default=False,
         json_schema_extra={"description": "Whether to use VLLM for RL training"},
     )
-    vllm_device: str | None = Field(
-        default="auto",
-        json_schema_extra={"description": "Device to use for VLLM"},
+    vllm_server_host: str | None = Field(
+        default="0.0.0.0",  # nosec B104
+        json_schema_extra={"description": "Host of the vLLM server to connect to"},
     )
-    vllm_gpu_memory_utilization: float | None = Field(
-        default=0.9,
-        json_schema_extra={"description": "GPU memory utilization for VLLM"},
+    vllm_server_port: int | None = Field(
+        default=8000,
+        json_schema_extra={"description": "Port of the vLLM server to connect to"},
     )
-    vllm_dtype: str | None = Field(
-        default="auto",
-        json_schema_extra={"description": "Data type for VLLM"},
-    )
-    vllm_max_model_len: int | None = Field(
+    vllm_server_timeout: int | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Maximum length of the model context for VLLM"
+            "description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
+            "after the timeout, a `ConnectionError` is raised."
+        },
+    )
+    vllm_guided_decoding_regex: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
         },
     )
 
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
             "description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
         },
     )
+    scale_rewards: bool = Field(
+        default=True,
+        json_schema_extra={
+            "description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
+        },
+    )
+
+    temperature: float | None = Field(
+        default=None,
+        json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
+    )
+    top_p: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Top-p sampling probability for the generation policy."
+        },
+    )
+    top_k: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "Top-k sampling for the generation policy."},
+    )
+    min_p: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Minimum probability for the generation policy."
+        },
+    )
+    repetition_penalty: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
+        },
+    )
+    num_iterations: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
+        },
+    )
+    epsilon: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Epsilon value for clipping in the GRPO algorithm."
+        },
+    )
diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py
new file mode 100644
index 000000000..bb1a4ba26
--- /dev/null
+++ b/src/axolotl/utils/schemas/vllm.py
@@ -0,0 +1,38 @@
+"""
+Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
+"""
+
+from pydantic import BaseModel, Field
+
+
+class VllmConfig(BaseModel):
+    """
+    Configuration for VLLM server
+    """
+
+    device: str | None = Field(
+        default="auto",
+        json_schema_extra={"description": "Device to use for VLLM"},
+    )
+    tensor_parallel_size: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "Tensor parallel size for VLLM"},
+    )
+    gpu_memory_utilization: float | None = Field(
+        default=0.9,
+        json_schema_extra={"description": "GPU memory utilization for VLLM"},
+    )
+    dtype: str | None = Field(
+        default="auto",
+        json_schema_extra={"description": "Data type for VLLM"},
+    )
+    max_model_len: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Maximum length of the model context for VLLM"
+        },
+    )
+    enable_prefix_caching: bool | None = Field(
+        default=None,
+        json_schema_extra={"description": "Enable prefix caching for VLLM"},
+    )
diff --git a/tests/conftest.py b/tests/conftest.py
index aa867ecb9..b86b714af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -100,6 +100,14 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
     )
 
 
+@pytest.fixture(scope="session", autouse=True)
+def download_argilla_distilabel_intel_orca_dpo_dataset():
+    # download the dataset
+    snapshot_download_w_retry(
+        "argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
+    )
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
     # download the dataset
diff --git a/tests/e2e/multigpu/solo/__init__.py b/tests/e2e/multigpu/solo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
new file mode 100644
index 000000000..bd999e2f3
--- /dev/null
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -0,0 +1,294 @@
+"""
+GRPO test suite
+"""
+
+import os
+import random
+import subprocess  # nosec B404
+import sys
+import time
+from pathlib import Path
+
+import pytest
+import requests
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import require_vllm
+
+
+def start_vllm(
+    model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
+) -> int:
+    """
+    helper function to start the VLLM server in the background, mostly for testing purposes
+    """
+    cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
+
+    if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
+        cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
+    if host := kwargs.get("host"):
+        cmd.extend(["--host", host])
+    if port := kwargs.get("port"):
+        cmd.extend(["--port", str(port)])
+    if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
+        cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
+    if dtype := kwargs.get("dtype"):
+        cmd.extend(["--dtype", dtype])
+    if max_model_len := kwargs.get("max_model_len"):
+        cmd.extend(["--max-model-len", str(max_model_len)])
+    if kwargs.get("enable_prefix_caching"):
+        cmd.extend(["--enable-prefix-caching", "True"])
+
+    # print out the command to be executed
+    print(" ".join(cmd))
+
+    # start `trl vllm-serve` command in the background and capture the process id
+    process = subprocess.Popen(  # pylint: disable=consider-using-with
+        cmd,
+        env=env,
+        stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
+        stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
+    )  # nosec B603
+
+    # print out the process id so the user can easily kill it later
+    print(f"VLLM server process started (PID: {process.pid})")
+
+    # wait until the http server is ready, even if it 404s, and give up after `wait` seconds
+    started = False
+    if wait and host and port:
+        for _ in range(int(wait)):
+            try:
+                response = requests.get(f"http://{host}:{port}", timeout=1)
+                if int(response.status_code) in [200, 404]:
+                    started = True
+                    break
+            except requests.exceptions.RequestException:
+                pass
+
+            # stop waiting if the server process has already exited
+            if process.poll() is not None:
+                break
+
+            time.sleep(1)
+
+    if wait and not started:
+        print(
+            f"VLLM server process did not start within {wait} seconds. Please check your server logs."
+        )
+        process.kill()
+        raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
+
+    # return the process id
+    return process.pid
+
+
+class TestGRPO:
+    """
+    Test case for GRPO training using multiple GPUs
+    """
+
+    def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+        with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
+            fout.write(
+                """import random
+def rand_reward_func(completions, **kwargs) -> list[float]:
+    return [random.uniform(0, 1) for _ in completions]
+
+def oai_gsm8k_transform(cfg, *args, **kwargs):
+    def transform_fn(example, tokenizer=None):
+        label = example["answer"].split("####")[-1].strip().replace(",", "")
+        return {
+            "prompt": [{"role": "user", "content": example["question"]},],
+            "answer": label,
+        }
+    return transform_fn, {"remove_columns": ["question"]}
+"""
+            )
+
+    @pytest.mark.parametrize(
+        "num_gpus",
+        [1, 2],
+    )
+    @require_vllm
+    def test_llama_dora(self, temp_dir, num_gpus):
+        rnd_reward_suffix = str(random.randint(1000, 9999))
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "chat_template": "llama3",
+                "rl": "grpo",
+                "trl": {
+                    "beta": 0.001,
+                    "max_completion_length": 256,
+                    "use_vllm": True,
+                    "num_generations": 4,
+                    "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
+                },
+                "vllm": {
+                    "max_model_len": 800,
+                    "enable_prefix_caching": True,
+                },
+                "datasets": [
+                    {
+                        "path": "openai/gsm8k",
+                        "name": "main",
+                        "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
+                    },
+                ],
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "peft_use_dora": True,
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "max_steps": 3,
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+                "warmup_steps": 10,
+                "val_set_size": 0.0,
+                "output_dir": temp_dir,
+                "learning_rate": 0.0001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "save_safetensors": True,
+                "bf16": "auto",
+                "use_tensorboard": True,
+            }
+        )
+
+        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
+
+        current_env = os.environ.copy()
+        env = {
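+            # nccl can be brittle, assume P2P isn't reliable (mirrors test_llama_fft below)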
"NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process_id = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=120, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + str(num_gpus), + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env}, + ) + finally: + os.kill(vllm_process_id, 9) + + @pytest.mark.parametrize( + "num_gpus", + [1, 2], + ) + @require_vllm + def test_llama_fft(self, temp_dir, num_gpus): + rnd_reward_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "grpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform", + }, + ], + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process_id = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=120, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + str(num_gpus), + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env}, + ) + finally: + os.kill(vllm_process_id, 9) diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py index 586da8577..4989b81df 100644 --- a/tests/e2e/multigpu/test_eval.py +++ b/tests/e2e/multigpu/test_eval.py @@ -52,9 +52,9 @@ class TestMultiGPUEval: }, ], "num_epochs": 1, - "max_steps": 5, + "max_steps": 2, "micro_batch_size": 2, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_8bit", @@ -121,9 +121,9 @@ class TestMultiGPUEval: }, ], "num_epochs": 1, - "max_steps": 5, + "max_steps": 2, "micro_batch_size": 2, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_8bit", diff --git a/tests/e2e/multigpu/test_grpo.py b/tests/e2e/multigpu/test_grpo.py deleted file mode 100644 index a879a7750..000000000 --- a/tests/e2e/multigpu/test_grpo.py +++ /dev/null @@ -1,175 +0,0 @@ -""" 
-GRPO test suite -""" - -import random -from pathlib import Path - -import pytest -import yaml -from accelerate.test_utils import execute_subprocess_async -from transformers.testing_utils import get_torch_dist_unique_port - -from axolotl.utils.dict import DictDefault - -from tests.e2e.utils import require_vllm - - -class TestGRPO: - """ - Test case for GRPO training using multilpe GPUs - """ - - def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""): - # write cfg to yaml file - Path(temp_dir).mkdir(parents=True, exist_ok=True) - with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: - fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) - with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout: - fout.write( - """import random -def rand_reward_func(completions, **kwargs) -> list[float]: - return [random.uniform(0, 1) for _ in completions] - -def oai_gsm8k_transform(cfg, *args, **kwargs): - def transform_fn(example, tokenizer=None): - label = example["answer"].split("####")[-1].strip().replace(",", "") - return { - "prompt": [{"role": "user", "content": example["question"]},], - "answer": label, - } - return transform_fn, {"remove_columns": ["question"]} -""" - ) - - @pytest.mark.parametrize( - "num_gpus", - [1, 2], - ) - @require_vllm - def test_llama_dora(self, temp_dir, num_gpus): - rnd_reward_suffix = str(random.randint(1000, 9999)) - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "chat_template": "llama3", - "rl": "grpo", - "trl": { - "beta": 0.001, - "max_completion_length": 256, - "use_vllm": True, - "vllm_device": "auto" if num_gpus == 1 else "cuda:1", - "vllm_gpu_memory_utilization": 0.15, - "num_generations": 4, - "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"], - }, - "datasets": [ - { - "path": "openai/gsm8k", - "name": "main", - "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform", - }, - ], - "adapter": "lora", - "lora_r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "lora_target_linear": True, - "peft_use_dora": True, - "flash_attention": True, - "sequence_len": 1024, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - "max_steps": 5, - "num_epochs": 1, - "micro_batch_size": 4, - "gradient_accumulation_steps": 2, - "warmup_steps": 10, - "val_set_size": 0.0, - "output_dir": temp_dir, - "learning_rate": 0.0001, - "optimizer": "adamw_torch_fused", - "lr_scheduler": "cosine", - "save_safetensors": True, - "bf16": "auto", - "use_tensorboard": True, - } - ) - - self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix) - - execute_subprocess_async( - [ - "axolotl", - "train", - str(Path(temp_dir) / "config.yaml"), - "--num-processes", - str(num_gpus), - "--main-process-port", - f"{get_torch_dist_unique_port()}", - ] - ) - - @pytest.mark.parametrize( - "num_gpus", - [1, 2], - ) - @require_vllm - def test_llama_fft(self, temp_dir, num_gpus): - rnd_reward_suffix = str(random.randint(1000, 9999)) - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "chat_template": "llama3", - "rl": "grpo", - "trl": { - "beta": 0.001, - "max_completion_length": 256, - "use_vllm": True, - "vllm_device": "auto" if num_gpus == 1 else "cuda:1", - "vllm_gpu_memory_utilization": 0.15, - "num_generations": 4, - "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"], - }, - "datasets": [ - { - "path": "openai/gsm8k", - "name": "main", - "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform", - }, - ], - "flash_attention": True, - "sequence_len": 1024, - 
"special_tokens": { - "pad_token": "<|endoftext|>", - }, - "max_steps": 5, - "num_epochs": 1, - "micro_batch_size": 4, - "gradient_accumulation_steps": 2, - "warmup_steps": 10, - "val_set_size": 0.0, - "output_dir": temp_dir, - "learning_rate": 0.0001, - "optimizer": "adamw_torch_fused", - "lr_scheduler": "cosine", - "save_safetensors": True, - "bf16": "auto", - "use_tensorboard": True, - } - ) - - self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix) - - execute_subprocess_async( - [ - "axolotl", - "train", - str(Path(temp_dir) / "config.yaml"), - "--num-processes", - str(num_gpus), - "--main-process-port", - f"{get_torch_dist_unique_port()}", - ] - ) diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 8a16ff096..432d89b1f 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -399,7 +399,7 @@ class TestMultiGPULlama: "num_epochs": 1, "max_steps": 2, "micro_batch_size": 4, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", @@ -478,7 +478,7 @@ class TestMultiGPULlama: "num_epochs": 1, "max_steps": 2, "micro_batch_size": 4, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", @@ -778,7 +778,7 @@ class TestMultiGPULlama: }, ], "num_epochs": 1, - "max_steps": 5, + "max_steps": 2, "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py index 1895e1ee8..af39c6361 100644 --- a/tests/e2e/multigpu/test_qwen2.py +++ b/tests/e2e/multigpu/test_qwen2.py @@ -46,7 +46,7 @@ class TestMultiGPUQwen2: }, ], "num_epochs": 1, - "max_steps": 5, + "max_steps": 2, "warmup_steps": 20, "micro_batch_size": 2, "gradient_accumulation_steps": 2, diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index 8e7916728..14b1c0a86 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -50,7 +50,7 @@ class TestMultiGPURay: "num_epochs": 1, "max_steps": 2, "micro_batch_size": 4, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_8bit",