diff --git a/ui/main.py b/ui/main.py index 4770e22ec..900ac0bb0 100644 --- a/ui/main.py +++ b/ui/main.py @@ -1,7 +1,6 @@ """ This module is used to launch Axolotl with user defined configurations. """ -import subprocess import gradio as gr import yaml @@ -39,23 +38,6 @@ def config( return yaml.dump(config_dict) -def create_training_job(): - # Start a long-running process - process = subprocess.Popen( - ["accelerate launch -m axolotl.cli.train config.yml"], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - ) - - # Read the output line by line as it becomes available - while True: - line = process.stdout.readline() - if not line: - break # No more output - print(line.strip()) - - with gr.Blocks(title="Axolotl Launcher") as demo: gr.Markdown( """ @@ -63,34 +45,36 @@ with gr.Blocks(title="Axolotl Launcher") as demo: Fill out the required fields below to create a training run. """ ) - base_model_name = gr.Textbox( - "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model" - ) + with gr.Row(): + base_model_name = gr.Textbox( + "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model" + ) + + mode = gr.Radio( + choices=["Full finetune", "QLoRA", "LoRA"], + label="Training mode", + info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit", + ) with gr.Row(): dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset") dataset_type_name = gr.Dropdown( choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca" ) - with gr.Row(): - learning_rate = gr.Number(0.000001, label="Learning rate") - gradient_accumulation_steps_count = gr.Number( - 1, label="Gradient accumulation steps" - ) - val_set_size_count = gr.Number(0, label="Validation size") + with gr.Accordion("Hyperparameters", open=False): + gr.Markdown("Choose hyperparameters") + with gr.Row(): + learning_rate = gr.Number(0.000001, label="Learning rate") + gradient_accumulation_steps_count = gr.Number( + 1, label="Gradient accumulation steps" + ) + val_set_size_count = gr.Number(0, label="Validation size") - with gr.Row(): - micro_batch_size_count = gr.Number(1, label="Micro batch size") - sequence_length = gr.Number(1024, label="Sequence length") - num_epochs_count = gr.Number(1, label="Epochs") + with gr.Row(): + micro_batch_size_count = gr.Number(1, label="Micro batch size") + sequence_length = gr.Number(1024, label="Sequence length") + num_epochs_count = gr.Number(1, label="Epochs") - output_dir_path = gr.Textbox("./model-out", label="Output directory") - - mode = gr.Radio( - choices=["Full finetune", "QLoRA", "LoRA"], - value="Full finetune", - label="Training mode", - info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit", - ) + output_dir_path = gr.Textbox("./model-out", label="Output directory") create_config = gr.Button("Create config") output = gr.TextArea(label="Generated config") @@ -111,7 +95,4 @@ with gr.Blocks(title="Axolotl Launcher") as demo: outputs=output, ) - start_training = gr.Button("Start training") - start_training.click(create_training_job) - -demo.launch(share=True) +demo.launch()