Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]
* add grpo scale_rewards config for trl#3135 * options to connect to vllm server directly w grpo trl#3094 * temperature support trl#3029 * sampling/generation kwargs for grpo trl#2989 * make vllm_enable_prefix_caching a config param trl#2900 * grpo multi-step optimizations trl#2899 * remove overrides for grpo trainer * bump trl to 0.16.0 * add cli to start vllm-serve via trl * call the python module directly * update to use vllm with 2.6.0 too now and call trl vllm serve from module * vllm 0.8.1 * use python3 * use sys.executable * remove context and wait for start * fixes to make it actually work * fixes so the grpo tests pass with new vllm paradigm * explicit host/port and check in start vllm * make sure that vllm doesn't hang by setting quiet so outputs go to dev null * also bump bnb to latest release * add option for wait from cli and nccl debugging for ci * grpo + vllm test on separate devices for now * make sure grpo + vllm tests run single worker since pynccl comms would conflict * fix cli * remove wait and add caching for argilla dataset * refactoring configs * chore: lint * add vllm config * fixup vllm grpo args * fix one more incorrect schema/config path * fix another vllm reference and increase timeout * make the tests run a bit faster * change mbsz back so it is correct for grpo * another change mbsz back so it is correct for grpo * fixing cli args * nits * adding docs * docs * include tensor parallel size for vllm in pydantic schema * moving start_vllm, more docs * limit output len for grpo vllm * vllm enable_prefix_caching isn't a bool cli arg * fix env ordering in tests and also use pid check when looking for vllm --------- Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -35,6 +35,55 @@ class TrainerCliArgs:
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
class VllmServeCliArgs:
    """Command-line options accepted by the `axolotl vllm-serve` command.

    Every field carries its option help text under ``metadata["help"]``.
    """

    # --- server topology and network binding ---
    tensor_parallel_size: int = field(
        metadata={"help": "Number of tensor parallel workers to use."},
        default=1,
    )
    host: str = field(
        metadata={"help": "Host address to run the server on."},
        default="0.0.0.0",  # nosec B104
    )
    port: int = field(
        metadata={"help": "Port to run the server on."},
        default=8000,
    )

    # --- vLLM engine tuning; all default to None ---
    gpu_memory_utilization: Optional[float] = field(
        metadata={
            "help": (
                "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, "
                "activations, and KV cache on the device dedicated to generation powered "
                "by vLLM. Higher values will increase the KV cache size and thus improve "
                "the model's throughput. However, if the value is too high, it may cause "
                "out-of-memory (OOM) errors during initialization."
            )
        },
        default=None,
    )
    dtype: Optional[str] = field(
        metadata={
            "help": (
                "Data type to use for vLLM generation. If set to 'auto', the data type "
                "will be automatically determined based on the model configuration. "
                "Find the supported values in the vLLM documentation."
            )
        },
        default=None,
    )
    max_model_len: Optional[int] = field(
        metadata={
            "help": (
                "If set, the `max_model_len` to use for vLLM. This can be useful when "
                "running with reduced `vllm_gpu_memory_utilization`, leading to a reduced "
                "KV cache size. If not set, vLLM will use the model context size, which "
                "might be much larger than the KV cache, leading to inefficiencies."
            )
        },
        default=None,
    )
    enable_prefix_caching: Optional[bool] = field(
        metadata={
            "help": (
                "Whether to enable prefix caching in vLLM. If set to `True`, ensure that "
                "the model and the hardware support this feature."
            )
        },
        default=None,
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||
|
||||
@@ -14,7 +14,12 @@ import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.cli.args import (
|
||||
EvaluateCliArgs,
|
||||
PreprocessCliArgs,
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
    """Start a vLLM server for the model described by CONFIG.

    CONFIG must be the path to an existing config file. The remaining options
    are generated from the fields of `VllmServeCliArgs`.
    """
    # NOTE(review): `cli_args` arrives as a plain dict of option values (it is
    # forwarded as-is to `do_vllm_serve`, which calls `.get` on it); presumably
    # `filter_none_kwargs` has already dropped options left as None -- confirm.
    do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
55
src/axolotl/cli/vllm_serve.py
Normal file
55
src/axolotl/cli/vllm_serve.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
|
||||
def do_vllm_serve(
    config: Union[Path, str],
    cli_args: dict,
):
    """
    Starts the vLLM server for serving LLM models used for online RL.

    The model to serve is `base_model` from the loaded config. Every other
    server setting is taken from `cli_args` when a value is present there and
    otherwise from the `vllm` section of the loaded config. A setting that is
    unset in both places is not forwarded, so TRL's `ScriptArguments` default
    applies.

    Args:
        config: path to the axolotl YAML config file
        cli_args: dict of command-line overrides, keyed by the field names of
            `VllmServeCliArgs`

    Returns:
        None. The resolved arguments are handed to TRL's `vllm_serve` `main`,
        which is the last call made here.
    """
    cfg = load_cfg(config)
    vllm_cfg = cfg.vllm

    def _resolve(name):
        # Compare against None instead of using `or`: an explicit falsy CLI
        # value (e.g. `enable_prefix_caching=False`) must win over the config
        # rather than being mistaken for "not provided".
        cli_value = cli_args.get(name)
        if cli_value is not None:
            return cli_value
        # getattr with a default also covers a config whose `vllm` section is None.
        return getattr(vllm_cfg, name, None)

    setting_names = (
        "tensor_parallel_size",
        "host",
        "port",
        "gpu_memory_utilization",
        "dtype",
        "max_model_len",
        "enable_prefix_caching",
    )
    # Only forward settings that resolved to a value; passing None explicitly
    # would override the defaults declared on ScriptArguments.
    serve_kwargs = {}
    for name in setting_names:
        value = _resolve(name)
        if value is not None:
            serve_kwargs[name] = value

    vllm_script_args = ScriptArguments(cfg.base_model, **serve_kwargs)
    vllm_serve_main(vllm_script_args)
|
||||
@@ -40,18 +40,15 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_device"] = (
|
||||
trl.vllm_device if trl.vllm_device else "auto"
|
||||
)
|
||||
|
||||
if trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
trl.vllm_gpu_memory_utilization
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
||||
if trl.vllm_server_timeout:
|
||||
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
if trl.vllm_guided_decoding_regex:
|
||||
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
|
||||
trl.vllm_guided_decoding_regex
|
||||
)
|
||||
|
||||
if trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
|
||||
@@ -70,6 +67,25 @@ class GRPOStrategy:
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
if trl.scale_rewards is not None:
|
||||
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||
|
||||
if trl.temperature is not None:
|
||||
grpo_args_kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
grpo_args_kwargs["top_p"] = trl.top_p
|
||||
if trl.top_k is not None:
|
||||
grpo_args_kwargs["top_k"] = trl.top_k
|
||||
if trl.min_p is not None:
|
||||
grpo_args_kwargs["min_p"] = trl.min_p
|
||||
if trl.repetition_penalty is not None:
|
||||
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||
|
||||
if trl.num_iterations is not None:
|
||||
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,16 +2,18 @@
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from contextlib import nullcontext
|
||||
|
||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
@@ -19,91 +21,49 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# pylint: disable=access-member-before-definition
|
||||
# Enable gradient checkpointing if requested
|
||||
if kwargs["args"].gradient_checkpointing:
|
||||
# Ensure use_cache is disabled
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(self.model) and hasattr(
|
||||
self.model.base_model, "gradient_checkpointing_enable"
|
||||
):
|
||||
self.model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# pylint: disable=unused-argument,redefined-builtin
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
@profiling_decorator
|
||||
def _move_model_to_vllm(self):
|
||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
gather_if_zero3 = (
|
||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||
# adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
# Update vLLM weights while parameters are gathered
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = (
|
||||
name.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", "")
|
||||
)
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(
|
||||
make_inputs_require_grad
|
||||
)
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
return model
|
||||
# pylint: enable=unused-argument,redefined-builtin
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather and update each parameter individually.
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
def _move_model_to_vllm(self):
|
||||
with unwrap_model_for_generation(
|
||||
self.model,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = (
|
||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
)
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", ""): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {
|
||||
k: v
|
||||
for k, v in state_dict.items()
|
||||
if unwrapped_model.prefix not in k
|
||||
}
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
for k, v in state_dict.items()
|
||||
if "original_module" not in k
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
@@ -46,6 +46,7 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,6 +87,9 @@ class AxolotlInputConfig(
|
||||
trl: TRLConfig | None = Field(
|
||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
reward_model: bool | None = None
|
||||
process_reward_model: bool | None = None
|
||||
num_labels: int | None = None
|
||||
|
||||
@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
|
||||
)
|
||||
|
||||
# GRPO specific args
|
||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||
use_vllm: bool | None = Field(
|
||||
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
|
||||
use_vllm: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||
)
|
||||
vllm_device: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
vllm_server_host: str | None = Field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to"},
|
||||
)
|
||||
vllm_gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
vllm_server_port: int | None = Field(
|
||||
default=8000,
|
||||
json_schema_extra={"description": "Port of the vLLM server to connect to"},
|
||||
)
|
||||
vllm_dtype: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
vllm_max_model_len: int | None = Field(
|
||||
vllm_server_timeout: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
|
||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
scale_rewards: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
|
||||
},
|
||||
)
|
||||
|
||||
temperature: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Top-p sampling probability for the generation policy."
|
||||
},
|
||||
)
|
||||
top_k: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Top-k sampling for the generation policy."},
|
||||
)
|
||||
min_p: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Minimum probability for the generation policy."
|
||||
},
|
||||
)
|
||||
repetition_penalty: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
|
||||
},
|
||||
)
|
||||
num_iterations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
|
||||
},
|
||||
)
|
||||
epsilon: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Epsilon value for clipping in the GRPO algorithm."
|
||||
},
|
||||
)
|
||||
|
||||
38
src/axolotl/utils/schemas/vllm.py
Normal file
38
src/axolotl/utils/schemas/vllm.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VllmConfig(BaseModel):
    """
    Configuration for VLLM server

    These values form the `vllm` section of the axolotl config. The
    `vllm-serve` CLI falls back to them for any setting that is not supplied
    on the command line.
    """

    device: str | None = Field(
        default="auto",
        json_schema_extra={"description": "Device to use for VLLM"},
    )
    tensor_parallel_size: int | None = Field(
        default=None,
        json_schema_extra={"description": "Tensor parallel size for VLLM"},
    )
    gpu_memory_utilization: float | None = Field(
        default=0.9,
        json_schema_extra={"description": "GPU memory utilization for VLLM"},
    )
    dtype: str | None = Field(
        default="auto",
        json_schema_extra={"description": "Data type for VLLM"},
    )
    max_model_len: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum length of the model context for VLLM"
        },
    )
    enable_prefix_caching: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Enable prefix caching for VLLM"},
    )
    # host/port are read from this section by the vllm-serve CLI fallback, so
    # they must be declared here to be settable from the YAML config. The
    # defaults mirror the CLI defaults and TRLConfig's vllm_server_host/port.
    host: str | None = Field(
        default="0.0.0.0",  # nosec B104
        json_schema_extra={"description": "Host for the vLLM server to start on"},
    )
    port: int | None = Field(
        default=8000,
        json_schema_extra={"description": "Port for the vLLM server to start on"},
    )
|
||||
Reference in New Issue
Block a user