create config

This commit is contained in:
Mads Henrichsen
2024-02-08 09:26:58 +01:00
parent a5724ef08d
commit ddb60883f5

View File

@@ -1,7 +1,6 @@
""" """
This module is used to launch Axolotl with user defined configurations. This module is used to launch Axolotl with user defined configurations.
""" """
import subprocess
import gradio as gr import gradio as gr
import yaml import yaml
@@ -39,23 +38,6 @@ def config(
return yaml.dump(config_dict) return yaml.dump(config_dict)
def create_training_job():
    """Launch an Axolotl training run and echo its output line by line.

    Runs ``accelerate launch -m axolotl.cli.train config.yml`` as a child
    process with stderr merged into stdout, printing each output line
    (stripped of surrounding whitespace) as soon as it becomes available.
    Blocks until the child closes its output stream and exits.
    """
    # Each argument must be its own list element: without ``shell=True`` a
    # one-element list is treated as the executable name, so the original
    # single-string form could never be launched.
    command = ["accelerate", "launch", "-m", "axolotl.cli.train", "config.yml"]

    # The context manager closes the pipe and waits for the child on exit,
    # so no file descriptor or zombie process is left behind.
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    ) as process:
        # Iterating the pipe yields lines as they arrive and stops at EOF.
        for line in process.stdout:
            print(line.strip())
with gr.Blocks(title="Axolotl Launcher") as demo: with gr.Blocks(title="Axolotl Launcher") as demo:
gr.Markdown( gr.Markdown(
""" """
@@ -63,34 +45,36 @@ with gr.Blocks(title="Axolotl Launcher") as demo:
Fill out the required fields below to create a training run. Fill out the required fields below to create a training run.
""" """
) )
base_model_name = gr.Textbox( with gr.Row():
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model" base_model_name = gr.Textbox(
) "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
)
mode = gr.Radio(
choices=["Full finetune", "QLoRA", "LoRA"],
label="Training mode",
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
)
with gr.Row(): with gr.Row():
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset") dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
dataset_type_name = gr.Dropdown( dataset_type_name = gr.Dropdown(
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca" choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
) )
with gr.Row(): with gr.Accordion("Hyperparameters", open=False):
learning_rate = gr.Number(0.000001, label="Learning rate") gr.Markdown("Choose hyperparameters")
gradient_accumulation_steps_count = gr.Number( with gr.Row():
1, label="Gradient accumulation steps" learning_rate = gr.Number(0.000001, label="Learning rate")
) gradient_accumulation_steps_count = gr.Number(
val_set_size_count = gr.Number(0, label="Validation size") 1, label="Gradient accumulation steps"
)
val_set_size_count = gr.Number(0, label="Validation size")
with gr.Row(): with gr.Row():
micro_batch_size_count = gr.Number(1, label="Micro batch size") micro_batch_size_count = gr.Number(1, label="Micro batch size")
sequence_length = gr.Number(1024, label="Sequence length") sequence_length = gr.Number(1024, label="Sequence length")
num_epochs_count = gr.Number(1, label="Epochs") num_epochs_count = gr.Number(1, label="Epochs")
output_dir_path = gr.Textbox("./model-out", label="Output directory") output_dir_path = gr.Textbox("./model-out", label="Output directory")
mode = gr.Radio(
choices=["Full finetune", "QLoRA", "LoRA"],
value="Full finetune",
label="Training mode",
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
)
create_config = gr.Button("Create config") create_config = gr.Button("Create config")
output = gr.TextArea(label="Generated config") output = gr.TextArea(label="Generated config")
@@ -111,7 +95,4 @@ with gr.Blocks(title="Axolotl Launcher") as demo:
outputs=output, outputs=output,
) )
start_training = gr.Button("Start training") demo.launch()
start_training.click(create_training_job)
demo.launch(share=True)