Gradio configuration parameters (#1591)
* Gradio Configuration Settings * Making various Gradio variables configurable instead of hardcoded * Remove overwriting behavour of 'default tokens' that breaks tokenizer for llama3 * Fix type of gradio_temperature * revert un-necessary change and lint --------- Co-authored-by: Marijn Stollenga <stollenga@imfusion.de> Co-authored-by: Marijn Stollenga <stollenga@imfusion.com> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
||||
temperature=cfg.get("gradio_temperature", 0.9),
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
@@ -300,7 +300,13 @@ def do_inference_gradio(
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
|
||||
@@ -409,6 +409,17 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
gradio_title: Optional[str] = None
|
||||
gradio_share: Optional[bool] = None
|
||||
gradio_server_name: Optional[str] = None
|
||||
gradio_server_port: Optional[int] = None
|
||||
gradio_max_new_tokens: Optional[int] = None
|
||||
gradio_temperature: Optional[float] = None
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods,too-many-ancestors
|
||||
class AxolotlInputConfig(
|
||||
ModelInputConfig,
|
||||
@@ -419,6 +430,7 @@ class AxolotlInputConfig(
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
|
||||
Reference in New Issue
Block a user