diff --git a/examples/lora-openllama-3b/config.yml b/examples/lora-openllama-3b/config.yml
new file mode 100644
index 000000000..6665044e0
--- /dev/null
+++ b/examples/lora-openllama-3b/config.yml
@@ -0,0 +1,67 @@
+base_model: openlm-research/open_llama_3b_600bt_preview
+base_model_config: openlm-research/open_llama_3b_600bt_preview
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+datasets:
+  - path: teknium/GPT4-LLM-Cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+adapter: lora
+lora_model_dir:
+sequence_len: 256
+max_packed_sequence_len:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.0
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+lora_fan_in_fan_out:
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-out
+batch_size: 16
+micro_batch_size: 4
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: true
+tf32: false
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 50
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index a6d237a11..fd9dfc8d4 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -17,8 +17,8 @@ class AlpacaPrompter:
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
     prompt_style = None
 
-    def __init__(self, prompt_style="instruct"):
-        self.prompt_style = prompt_style
+    def __init__(self, prompt_style=PromptStyle.instruct.value):
+        self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
         self.match_prompt_style()
 
     def match_prompt_style(self):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5b243bec4..de04e9333 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -211,12 +211,12 @@ def load_model(
     try:
         if is_llama_derived_model and "LlamaTokenizer" in globals():
            tokenizer = LlamaTokenizer.from_pretrained(
-                model,
+                base_model_config,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
             )
         else:
             tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-                model,
+                base_model_config,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
             )
     except:
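
Usage note: with the new example config in place, a run would presumably be launched through the project's training entry point, e.g. `accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml`. The `scripts/finetune.py` path is an assumption based on the repository layout of this era, not something this diff adds or verifies.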
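
For reference, a minimal standalone sketch of the new `AlpacaPrompter.__init__` default behavior. The `PromptStyle` enum body below is an assumption inferred from the `PromptStyle.instruct.value` reference in the patch; only the changed constructor lines come from the patch itself, and `match_prompt_style()` is omitted to keep the sketch self-contained:

from enum import Enum

class PromptStyle(Enum):
    instruct = "instruct"  # the member the diff references
    chat = "chat"          # assumed additional member, for illustration only

class AlpacaPrompter:
    def __init__(self, prompt_style=PromptStyle.instruct.value):
        # A caller passing prompt_style=None (e.g. from an unset config key)
        # now falls back to "instruct" instead of storing None.
        self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value

assert AlpacaPrompter(None).prompt_style == "instruct"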
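
Similarly, a sketch of the tokenizer fix in `load_model`: `from_pretrained()` expects a Hub repo id or local path, so the tokenizer is now resolved from the `base_model_config` string rather than the instantiated `model` object. The `load_tokenizer` wrapper below is illustrative (it simplifies the diff's `is_llama_derived_model and "LlamaTokenizer" in globals()` check to a plain type-name test) and is not the repository's API:

import transformers
from transformers import LlamaTokenizer

def load_tokenizer(base_model_config, tokenizer_type="LlamaTokenizer", trust_remote_code=False):
    # Resolve the tokenizer from the model identifier/path (base_model_config),
    # mirroring the corrected call sites in src/axolotl/utils/models.py.
    if tokenizer_type == "LlamaTokenizer":
        return LlamaTokenizer.from_pretrained(
            base_model_config, trust_remote_code=trust_remote_code
        )
    return getattr(transformers, tokenizer_type).from_pretrained(
        base_model_config, trust_remote_code=trust_remote_code
    )

# e.g.:
# tokenizer = load_tokenizer("openlm-research/open_llama_3b_600bt_preview")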