Gradio configuration parameters (#1591)
* Gradio Configuration Settings * Making various Gradio variables configurable instead of hardcoded * Remove overwriting behaviour of 'default tokens' that breaks tokenizer for llama3 * Fix type of gradio_temperature * revert unnecessary change and lint --------- Co-authored-by: Marijn Stollenga <stollenga@imfusion.de> Co-authored-by: Marijn Stollenga <stollenga@imfusion.com> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
repetition_penalty=1.1,
|
repetition_penalty=1.1,
|
||||||
max_new_tokens=1024,
|
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
||||||
temperature=0.9,
|
temperature=cfg.get("gradio_temperature", 0.9),
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=40,
|
top_k=40,
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
@@ -300,7 +300,13 @@ def do_inference_gradio(
|
|||||||
outputs="text",
|
outputs="text",
|
||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
)
|
)
|
||||||
demo.queue().launch(show_api=False, share=True)
|
|
||||||
|
demo.queue().launch(
|
||||||
|
show_api=False,
|
||||||
|
share=cfg.get("gradio_share", True),
|
||||||
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
|
|||||||
@@ -409,6 +409,17 @@ class WandbConfig(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class GradioConfig(BaseModel):
|
||||||
|
"""Gradio configuration subset"""
|
||||||
|
|
||||||
|
gradio_title: Optional[str] = None
|
||||||
|
gradio_share: Optional[bool] = None
|
||||||
|
gradio_server_name: Optional[str] = None
|
||||||
|
gradio_server_port: Optional[int] = None
|
||||||
|
gradio_max_new_tokens: Optional[int] = None
|
||||||
|
gradio_temperature: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods,too-many-ancestors
|
# pylint: disable=too-many-public-methods,too-many-ancestors
|
||||||
class AxolotlInputConfig(
|
class AxolotlInputConfig(
|
||||||
ModelInputConfig,
|
ModelInputConfig,
|
||||||
@@ -419,6 +430,7 @@ class AxolotlInputConfig(
|
|||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
|
GradioConfig,
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
DeprecatedParameters,
|
DeprecatedParameters,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
|||||||
Reference in New Issue
Block a user