diff --git a/scripts/finetune.py b/scripts/finetune.py index 915ba1de1..ff2f87eb1 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -191,7 +191,9 @@ def train( if cfg.debug: logging.info("check_dataset_labels...") check_dataset_labels( - train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]), + train_dataset.select( + [random.randrange(0, len(train_dataset) - 1) for i in range(5)] + ), tokenizer, ) @@ -218,17 +220,20 @@ def train( logging.info("Starting trainer...") resume_from_checkpoint = cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: - possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")] + possible_checkpoints = [ + str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") + ] if len(possible_checkpoints) > 0: - sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split('-')[-1])) + sorted_paths = sorted( + possible_checkpoints, key=lambda path: int(path.split("-")[-1]) + ) resume_from_checkpoint = sorted_paths[-1] - logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}") + logging.info( + f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}" + ) trainer.train(resume_from_checkpoint=resume_from_checkpoint) - logging.info( - f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}" - ) - + logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading trainer.save_pretrained(cfg.output_dir) diff --git a/setup.py b/setup.py index a183bcda1..134e4be66 100644 --- a/setup.py +++ b/setup.py @@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file: install_requires.append(r) setup( - name='axolotl', - version='0.1', + name="axolotl", + version="0.1", description="You know you're going to axolotl questions", - package_dir={'': 'src'}, + package_dir={"": "src"}, packages=find_packages(), install_requires=install_requires, extras_require={ - 'int4': [ + "int4": [ "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip", ], - 'int4_triton': [ + "int4_triton": [ "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip", ], - 'extras': [ - 'flash-attn', - 'deepspeed', - ] + "extras": [ + "flash-attn", + "deepspeed", + ], }, ) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index deab5e438..d9acf5715 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset): except InvalidDataException: pass + # TODO this isn't the best since it can't interleave datasets class ConstantLengthDataset(IterableDataset): """ @@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset): dataset (dataset.Dataset): Dataset with text files. seq_length (int): Length of token sequences to return. """ + def __init__( self, tokenizer, @@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset): : self.seq_length ] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] - if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size(): + if ( + labels.size() == input_ids.size() + and attention_mask.size() == input_ids.size() + ): yield { "input_ids": input_ids, "labels": labels, "attention_mask": attention_mask, } else: - logging.warning("dropping batch due to tensor size mismatch") + logging.warning( + "dropping batch due to tensor size mismatch" + ) buffer = {"input_ids": [], "attention_mask": [], "labels": []} buffer_len = 0 @@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset): attention_mask.append(1) labels.append(self.concat_token_id) - input_ids_with_concat = torch.tensor(input_ids, dtype=self.tokens_dtype) + input_ids_with_concat = torch.tensor( + input_ids, dtype=self.tokens_dtype + ) attention_mask_with_concat = torch.tensor( attention_mask, dtype=self.tokens_dtype ) - labels_with_concat = torch.tensor(labels, dtype=self.tokens_dtype) + labels_with_concat = torch.tensor( + labels, dtype=self.tokens_dtype + ) buffer["input_ids"].append(input_ids_with_concat) buffer["attention_mask"].append(attention_mask_with_concat) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 167648618..00d8ecbf9 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str): - return ( - prompt["text"] - ) + def parse_instruction_fields(self, prompt) -> str: + return prompt["text"] def tokenize_prompt(self, prompt): instruction = self.parse_instruction_fields(prompt) @@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): return tokenized_full_prompt def _build_full_prompt(self, instruction): - return self.prompter.build_prompt( - instruction - ) + return self.prompter.build_prompt(instruction) class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): @@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): raise NotImplementedError def tokenize_prompt(self, prompt): - instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt) - full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected) + ( + instruction, + input, + output, + reflection, + corrected, + ) = self.parse_instruction_fields(prompt) + full_prompt = self._build_full_prompt( + instruction, input, output, reflection, corrected + ) tokenized_full_prompt = self._tokenize(full_prompt) if not self.train_on_inputs: user_prompt = self.prompter.build_prompt( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 914cbd0de..3dc5d6433 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter): class CompletionPrompter(AlpacaPrompter): - def build_prompt( - self, - instruction: str - ) -> str: + def build_prompt(self, instruction: str) -> str: return instruction def get_response(self, output: str) -> str: @@ -75,7 +72,9 @@ class ReflectAlpacaPrompter: else: res = self.prompt_no_input.format(instruction=instruction) if output and reflection and corrected: - label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected) + label = self.agent_label.format( + output=output, reflection=reflection, corrected=corrected + ) res = f"{res}{label}" return res @@ -200,9 +199,13 @@ class ShareGPTPrompter: if len(parts) != 2: break parts[0] += sep - round_len = len(tokenizer(rou)["input_ids"]) - 1 # -1 ignores the bos_token generated for this + round_len = ( + len(tokenizer(rou)["input_ids"]) - 1 + ) # -1 ignores the bos_token generated for this # we have to strip the initial part, any dangling whitespace creates an additional ghost token - instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1 # -1 ignores the bos_token generated for this + instruction_len = ( + len(tokenizer(parts[0].strip())["input_ids"]) - 1 + ) # -1 ignores the bos_token generated for this target[cur_len : cur_len + instruction_len] = [ IGNORE_TOKEN_ID ] * instruction_len @@ -212,7 +215,7 @@ class ShareGPTPrompter: break # Fix: Truncate the target to have the same length as input_ids - target = target[:len(tokenized_result["input_ids"])] + target = target[: len(tokenized_result["input_ids"])] # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len) attention_mask = [ diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index aaf96bcb0..229cd9b98 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -1,8 +1,15 @@ import os -from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl +from transformers import ( + Seq2SeqTrainer, + TrainerCallback, + TrainingArguments, + TrainerState, + TrainerControl, +) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + class SavePeftModelCallback(TrainerCallback): def on_save( self, @@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback): control: TrainerControl, **kwargs, ): - checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") + checkpoint_folder = os.path.join( + args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" + ) peft_model_path = os.path.join(checkpoint_folder, "adapter_model") kwargs["model"].save_pretrained(peft_model_path) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index b217b50d7..581b48a88 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -2,7 +2,13 @@ import logging from hashlib import md5 from pathlib import Path -from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets +from datasets import ( + load_from_disk, + load_dataset, + IterableDataset, + Dataset, + concatenate_datasets, +) from huggingface_hub import hf_hub_download from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset @@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): else: ds = load_dataset(d.path, streaming=True) else: - fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files) + fp = hf_hub_download( + repo_id=d.path, repo_type="dataset", filename=d.data_files + ) ds = load_dataset("json", data_files=fp, streaming=True, split=None) if not ds: raise Exception("unhandled dataset load") @@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): samples = samples + [i for i in d] dataset = Dataset.from_list(samples).shuffle(seed=42) if cfg.local_rank == 0: - logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") + logging.info( + f"Saving merged prepared dataset to disk... {prepared_ds_path}" + ) dataset.save_to_disk(prepared_ds_path) if cfg.max_packed_sequence_len is not None: @@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): dataset = Dataset.from_list([_ for _ in constant_len_dataset]) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: - logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards") - dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx) + logging.info( + f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" + ) + dataset = dataset.shard( + num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx + ) - dataset = dataset.train_test_split( - test_size=cfg.val_set_size, shuffle=False - ) + dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) train_dataset = dataset["train"] eval_dataset = dataset["test"] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8c80b2621..2ca84b795 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -9,14 +9,18 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, + AutoConfig, ) + try: from transformers import ( LlamaForCausalLM, LlamaTokenizer, ) except: - logging.warning("This version of transformers does not support Llama. Consider upgrading.") + logging.warning( + "This version of transformers does not support Llama. Consider upgrading." + ) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN @@ -40,7 +44,9 @@ def load_model( # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit tokenizer = None - is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower()) + is_llama_derived_model = "llama" in base_model or ( + cfg.model_type and "llama" in cfg.model_type.lower() + ) if is_llama_derived_model and cfg.flash_attention: if cfg.device not in ["mps", "cpu"] and inference is False: @@ -49,11 +55,16 @@ def load_model( logging.info("patching with flash attention") replace_llama_attn_with_flash_attn() elif is_llama_derived_model and cfg.xformers_attention: - from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention + from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import ( + hijack_llama_attention, + ) + logging.info("patching with xformers attention") hijack_llama_attention() - torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32 + torch_dtype = ( + torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32 + ) try: if cfg.load_4bit: from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( @@ -74,8 +85,12 @@ def load_model( try: snapshot_download_kwargs = {} if cfg.base_model_ignore_patterns: - snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns - cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs)) + snapshot_download_kwargs[ + "ignore_patterns" + ] = cfg.base_model_ignore_patterns + cache_model_path = Path( + snapshot_download(base_model, **snapshot_download_kwargs) + ) files = ( list(cache_model_path.glob("*.pt")) + list(cache_model_path.glob("*.safetensors")) @@ -116,8 +131,13 @@ def load_model( trust_remote_code=True if cfg.trust_remote_code is True else False, ) else: + config = AutoConfig.from_pretrained( + base_model, + trust_remote_code=True if cfg.trust_remote_code is True else False, + ) model = AutoModelForCausalLM.from_pretrained( base_model, + config=config, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index 72916f037..b9b7e25be 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler): if self.last_epoch <= 0: lrs = [self.min_lr for base_lr in self.base_lrs] elif self.last_epoch < self.num_steps: - lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs] + lrs = [ + self.min_lr * (self.q ** (self.last_epoch - 1)) + for base_lr in self.base_lrs + ] else: lrs = [self.max_lr for base_lr in self.base_lrs] diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index b9ffb1e1b..f23ca8a92 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -1,6 +1,7 @@ from termcolor import colored import logging + def check_dataset_labels(dataset, tokenizer): # the dataset is already shuffled, so let's just check the first 5 elements for idx in range(5): @@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer): # Get the input_ids, labels, and attention_mask from the dataset input_ids = example["input_ids"] labels = example["labels"] - attention_mask =example["attention_mask"] + attention_mask = example["attention_mask"] # You can compare the input_ids and labels element-wise # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 @@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer): ): decoded_input_token = tokenizer.decode(input_id) # Choose the color based on whether the label has the ignore value or not - color = ( - "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") - ) + color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") colored_token = colored(decoded_input_token, color) + colored( f"({label_id}, {mask}, {input_id})", "white" ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9ef1ac95b..df52afa26 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.logging_steps is not None else max(min(int(0.005 * total_num_steps), 10), 1) ) - save_steps = ( - cfg.save_steps - if cfg.save_steps is not None - else min(int(0.05 * total_num_steps), 200) - ) - eval_steps = ( - cfg.eval_steps - if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0 - else save_steps - ) + save_steps = cfg.save_steps + eval_steps = cfg.eval_steps training_arguments_kwargs = {} if cfg.bf16 == "full": @@ -86,26 +78,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_args = transformers.TrainingArguments( per_device_train_batch_size=cfg.micro_batch_size, - per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size, + per_device_eval_batch_size=cfg.eval_batch_size + if cfg.eval_batch_size is not None + else cfg.micro_batch_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, eval_accumulation_steps=cfg.gradient_accumulation_steps, num_train_epochs=cfg.num_epochs, learning_rate=cfg.learning_rate, evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", - save_strategy="steps", + save_strategy="steps" if save_steps else "epoch", eval_steps=eval_steps if cfg.val_set_size > 0 else None, save_steps=save_steps, output_dir=cfg.output_dir, save_total_limit=3, load_best_model_at_end=True - if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True + if cfg.val_set_size > 0 + and save_steps is not None + and save_steps % eval_steps == 0 + and cfg.load_in_8bit is not True else False, ddp_find_unused_parameters=False if cfg.ddp else None, group_by_length=cfg.group_by_length, report_to="wandb" if cfg.use_wandb else None, run_name=cfg.wandb_run_id if cfg.use_wandb else None, optim=cfg.optimizer if cfg.optimizer else "adamw_hf", - lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", + lr_scheduler_type=cfg.lr_scheduler + if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") + else "cosine", weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, **training_arguments_kwargs, ) @@ -158,6 +157,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): cfg.learning_rate, total_steps=total_num_steps, epochs=cfg.num_epochs, + div_factor=10, **lr_scheduler_kwargs, ) elif cfg.lr_scheduler == "log_sweep": @@ -191,7 +191,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): data_collator_kwargs["pad_to_multiple_of"] = 8 callbacks = [] - if cfg.adapter == 'lora': + if cfg.adapter == "lora": callbacks.append(SavePeftModelCallback) trainer = transformers.Trainer( diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py index 1e805c6c6..992bb1a5f 100644 --- a/src/axolotl/utils/wandb.py +++ b/src/axolotl/utils/wandb.py @@ -2,7 +2,9 @@ import os def setup_wandb_env_vars(cfg): - if cfg.wandb_project and len(cfg.wandb_project) > 0: + if cfg.wandb_mode and cfg.wandb_mode == "offline": + os.environ["WANDB_MODE"] = cfg.wandb_mode + elif cfg.wandb_project and len(cfg.wandb_project) > 0: os.environ["WANDB_PROJECT"] = cfg.wandb_project cfg.use_wandb = True if cfg.wandb_watch and len(cfg.wandb_watch) > 0: