diff --git a/ui/main.py b/ui/main.py new file mode 100644 index 000000000..8da0145d9 --- /dev/null +++ b/ui/main.py @@ -0,0 +1,95 @@ +""" +This module is used to launch Axolotl with user defined configurations. +""" +import gradio as gr +import yaml + + +def config( + base_model, + dataset, + dataset_type, + learn_rate, + gradient_accumulation_steps, + micro_batch_size, + seq_length, + num_epochs, + output_dir, + val_size, +): + """ + This function generates a configuration dictionary and saves it as a yaml file. + """ + config_dict = { + "base_model": base_model, + "datasets": [{"path": dataset, "type": dataset_type}], + "learning_rate": learn_rate, + "gradient_accumulation_steps": gradient_accumulation_steps, + "micro_batch_size": micro_batch_size, + "sequence_len": seq_length, + "num_epochs": num_epochs, + "output_dir": output_dir, + "val_set_size": val_size, + } + with open("config.yaml", "w", encoding="utf-8") as file: + yaml.dump(config_dict, file) + print(config_dict) + return yaml.dump(config_dict) + + +with gr.Blocks(title="Axolotl Launcher") as demo: + gr.Markdown( + """ + # Axolotl Launcher + Fill out the required fields below to create a training run. + """ + ) + base_model_name = gr.Textbox( + "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model" + ) + with gr.Row(): + dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset") + dataset_type_name = gr.Dropdown( + choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca" + ) + with gr.Row(): + learning_rate = gr.Number(0.000001, label="Learning rate") + gradient_accumulation_steps_count = gr.Number( + 1, label="Gradient accumulation steps" + ) + val_set_size_count = gr.Number(0, label="Validation size") + + with gr.Row(): + micro_batch_size_count = gr.Number(1, label="Micro batch size") + sequence_length = gr.Number(1024, label="Sequence length") + num_epochs_count = gr.Number(1, label="Epochs") + + output_dir_path = gr.Textbox("./model-out", label="Output directory") + + mode = gr.Radio( + choices=["Full finetune", "QLoRA", "LoRA"], + value="Full finetune", + label="Training mode", + info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit", + ) + + create_config = gr.Button("Create config") + output = gr.TextArea(label="Generated config") + create_config.click( + config, + inputs=[ + base_model_name, + dataset_path, + dataset_type_name, + learning_rate, + gradient_accumulation_steps_count, + micro_batch_size_count, + sequence_length, + num_epochs_count, + output_dir_path, + val_set_size_count, + ], + outputs=output, + ) + +demo.launch(share=True)