diff --git a/scripts/finetune.py b/scripts/finetune.py
index fa09f401a..a4b145f02 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -37,6 +37,7 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
     LLAMA_DEFAULT_PAD_TOKEN,
     GPTeacherPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
 
@@ -56,7 +57,15 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+def load_model(
+    base_model,
+    base_model_config,
+    model_type,
+    tokenizer_type,
+    cfg,
+    adapter="lora",
+    inference: bool = False,
+):
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
@@ -67,13 +76,17 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 
 
-    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32
     try:
         if cfg.load_4bit:
-            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
+                replace_peft_model_with_int4_lora_model,
+            )
+
             replace_peft_model_with_int4_lora_model()
 
         from peft import (
@@ -92,18 +105,26 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
             from huggingface_hub import snapshot_download
 
             cache_model_path = Path(snapshot_download(base_model))
-            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
+            files = (
+                list(cache_model_path.glob("*.pt"))
+                + list(cache_model_path.glob("*.safetensors"))
+                + list(cache_model_path.glob("*.bin"))
+            )
             if len(files) > 0:
                 model_path = str(files[0])
             else:
-                logging.warning("unable to find a cached model file, this will likely fail...")
+                logging.warning(
+                    "unable to find a cached model file, this will likely fail..."
+                )
                 model_path = str(cache_model_path)
             model, tokenizer = load_llama_model_4bit_low_ram(
                 base_model_config if base_model_config else base_model,
                 model_path,
                 device_map=cfg.device_map,
                 groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
-                is_v1_model=cfg.gptq_model_v1 if cfg.gptq_model_v1 is not None else True,
+                is_v1_model=cfg.gptq_model_v1
+                if cfg.gptq_model_v1 is not None
+                else True,
             )
             load_in_8bit = False
         elif is_llama_derived_model:
@@ -120,7 +141,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
-    except:
+    except Exception as e:
+        logging.error(
+            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
+        )
+        logging.exception(e)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
@@ -145,7 +170,6 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -165,7 +189,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     )
 
     if cfg.lora_model_dir:
-        model = PeftModel.from_pretrained(model, cfg.lora_model_dir, device_map = cfg.device_map, torch_dtype=torch.float16)
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
     else:
         model = get_peft_model(model, lora_config)
 
@@ -174,9 +203,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if cfg.load_4bit:
         # Scales to half
-        logging.info('Fitting 4bit scales and zeros to half')
+        logging.info("Fitting 4bit scales and zeros to half")
         for n, m in model.named_modules():
-            if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
+            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
+                type(m)
+            ):
                 if hasattr(m, "is_v1_model") and m.is_v1_model:
                     m.zeros = m.zeros.half()
                     m.scales = m.scales.half()
@@ -236,37 +267,44 @@ def check_dataset_labels(dataset, tokenizer):
 
 
 def do_inference(cfg, model, tokenizer):
-    tokenizer.add_special_tokens({'unk_token': '<unk>'})
-    tokenizer.add_special_tokens({'bos_token': '<s>'})
-    tokenizer.add_special_tokens({'eos_token': '</s>'})
+    tokenizer.add_special_tokens({"unk_token": "<unk>"})
+    tokenizer.add_special_tokens({"bos_token": "<s>"})
+    tokenizer.add_special_tokens({"eos_token": "</s>"})
 
     instruction = "Tell me a joke about dromedaries."
     input = ""
-    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input)
+    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
+        instruction=instruction, input=input
+    )
     batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
     model.eval()
     with torch.no_grad():
         # gc = GenerationConfig() # TODO swap out and use this
-        generated = model.generate(inputs=batch["input_ids"].to("cuda"),
-                                   do_sample=True, use_cache=True,
-                                   repetition_penalty=1.1,
-                                   max_new_tokens=100,
-                                   temperature=0.9,
-                                   top_p=0.95,
-                                   top_k=40,
-                                   return_dict_in_generate=True,
-                                   output_attentions=False,
-                                   output_hidden_states=False,
-                                   output_scores=False)
-        print(tokenizer.decode(generated['sequences'].cpu().tolist()[0]))
+        generated = model.generate(
+            inputs=batch["input_ids"].to("cuda"),
+            do_sample=True,
+            use_cache=True,
+            repetition_penalty=1.1,
+            max_new_tokens=100,
+            temperature=0.9,
+            top_p=0.95,
+            top_k=40,
+            return_dict_in_generate=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            output_scores=False,
+        )
+        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
 def choose_config(path: Path):
     yaml_files = [file for file in path.glob("*.yml")]
 
     if not yaml_files:
-        raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
+        raise ValueError(
+            "No YAML config files found in the specified directory. Are you using a .yml extension?"
+        )
 
     print("Choose a YAML file:")
     for idx, file in enumerate(yaml_files):
@@ -376,6 +414,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     return trainer
 
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -420,7 +459,13 @@ def train(
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and lora_config...")
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
+        cfg.base_model,
+        cfg.base_model_config,
+        cfg.model_type,
+        cfg.tokenizer_type,
+        cfg,
+        adapter=cfg.adapter,
+        inference=("inference" in kwargs),
     )
 
     if "inference" in kwargs:
@@ -428,10 +473,26 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
-    max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
-    ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
-    prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+    ds_hash = str(
+        md5(
+            (
+                str(max_packed_sequence_len)
+                + "@"
+                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+            ).encode("utf-8")
+        ).hexdigest()
+    )
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
 
     if any(prepared_ds_path.glob("*")):
         logging.info("Loading prepared dataset from disk...")
@@ -464,9 +525,18 @@ def train(
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            elif d.type == "oasst":
+                ds_strategy = OpenAssistantPromptTokenizingStrategy(
+                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
             elif d.type == "gpteacher":
                 ds_strategy = GPTeacherPromptTokenizingStrategy(
-                    GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    GPTeacherPrompter(),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
@@ -476,13 +546,17 @@ def train(
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            else:
+                logging.error(f"unhandled prompt tokenization strategy: {d.type}")
         constant_len_dataset = ConstantLengthDataset(
-            tokenizer, datasets, seq_length=max_packed_sequence_len,
+            tokenizer,
+            datasets,
+            seq_length=max_packed_sequence_len,
         )
         logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list(
-            [_ for _ in constant_len_dataset]
-        ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+            test_size=cfg.val_set_size, shuffle=True, seed=42
+        )
 
         if cfg.local_rank == 0:
             logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
@@ -525,7 +599,9 @@ def train(
 
     if cfg.local_rank == 0:
         # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+        logging.info(
+            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+        )
         model.save_pretrained(cfg.output_dir)
 
 
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 13d386cb4..8bbcfaaba 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
         pass
 
 
-class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
+class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        raise NotImplementedError
+
     def tokenize_prompt(self, prompt):
-        full_prompt = self._tokenize_full_prompt(prompt)
+        instruction, input, response = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(
-                prompt["instruction"],
-                prompt["input"] if "input" in prompt else "",
+                instruction,
+                input,
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
 
         return tokenized_full_prompt
 
-    def _tokenize_full_prompt(self, prompt):
+    def _build_full_prompt(self, instruction, input, response):
         return self.prompter.build_prompt(
-            prompt["instruction"],
-            prompt["input"] if "input" in prompt else "",
-            prompt["output"],
+            instruction,
+            input,
+            response,
         )
 
     def _tokenize(self, prompt, add_eos_token=True):
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         return result
 
 
-class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
-    def _tokenize_full_prompt(self, prompt):
-        return self.prompter.build_prompt(
+class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
             prompt["instruction"],
-            prompt["input"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+        )
+
+
+class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["INSTRUCTION"],
+            "",
+            prompt["RESPONSE"],
+        )
+
+
+class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
             prompt["response"],
         )
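
Below is a minimal, self-contained sketch (not part of the patch) of exercising the new OpenAssistantPromptTokenizingStrategy by hand, mirroring the positional arguments used in scripts/finetune.py (prompter, tokenizer, train_on_inputs, sequence_len). The tokenizer checkpoint and the example record are illustrative assumptions.

# Illustrative sketch only; assumes a LLaMA-style tokenizer, with "huggyllama/llama-7b"
# standing in for whatever cfg.base_model would point at.
from transformers import AutoTokenizer

from axolotl.prompt_tokenizers import OpenAssistantPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# prompter, tokenizer, train_on_inputs, sequence_len -- same order as the call in train()
strategy = OpenAssistantPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, False, 2048)

# OpenAssistant-style record: parse_instruction_fields() maps the INSTRUCTION/RESPONSE
# keys to (instruction, "", response) before the Alpaca prompt template is applied.
example = {
    "INSTRUCTION": "Tell me a joke about dromedaries.",
    "RESPONSE": "What do you call a camel with no humps? Humphrey.",
}

tokenized = strategy.tokenize_prompt(example)
print(len(tokenized["input_ids"]))

With train_on_inputs=False, tokenize_prompt tokenizes the user prompt separately so the instruction portion can be excluded from the training labels, the same behaviour the Alpaca and GPTeacher strategies inherit from InstructionPromptTokenizingStrategy.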