Gradio configuration parameters (#1591)

* Gradio Configuration Settings

* Making various Gradio variables configurable instead of hardcoded

* Remove overwriting behavour of 'default tokens' that breaks tokenizer for llama3

* Fix type of gradio_temperature

* revert un-necessary change and lint

---------

Co-authored-by: Marijn Stollenga <stollenga@imfusion.de>
Co-authored-by: Marijn Stollenga <stollenga@imfusion.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
marijnfs
2024-05-06 21:43:42 +02:00
committed by GitHub
parent 1ac899800b
commit 3367fca732
2 changed files with 21 additions and 3 deletions

View File

@@ -264,8 +264,8 @@ def do_inference_gradio(
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
@@ -300,7 +300,13 @@ def do_inference_gradio(
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(show_api=False, share=True)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def choose_config(path: Path):

View File

@@ -409,6 +409,17 @@ class WandbConfig(BaseModel):
return data
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
gradio_title: Optional[str] = None
gradio_share: Optional[bool] = None
gradio_server_name: Optional[str] = None
gradio_server_port: Optional[int] = None
gradio_max_new_tokens: Optional[int] = None
gradio_temperature: Optional[float] = None
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
@@ -419,6 +430,7 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,