diff --git a/scripts/extract_lora.py b/scripts/extract_lora.py deleted file mode 100644 index be88c5705..000000000 --- a/scripts/extract_lora.py +++ /dev/null @@ -1,163 +0,0 @@ -# import logging -# import os -# import random -# import signal -# import sys -# from pathlib import Path - -# import fire -# import torch -# import yaml -# from addict import Dict - -# from peft import set_peft_model_state_dict, get_peft_model_state_dict - -# # add src to the pythonpath so we don't need to pip install this -# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -# src_dir = os.path.join(project_root, "src") -# sys.path.insert(0, src_dir) - -# from axolotl.utils.data import load_prepare_datasets -# from axolotl.utils.models import load_model -# from axolotl.utils.trainer import setup_trainer -# from axolotl.utils.wandb import setup_wandb_env_vars - -# logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) - - -# def choose_device(cfg): -# def get_device(): -# if torch.cuda.is_available(): -# return "cuda" -# else: -# try: -# if torch.backends.mps.is_available(): -# return "mps" -# except: -# return "cpu" - -# cfg.device = get_device() -# if cfg.device == "cuda": -# cfg.device_map = {"": cfg.local_rank} -# else: -# cfg.device_map = {"": cfg.device} - - -# def choose_config(path: Path): -# yaml_files = [file for file in path.glob("*.yml")] - -# if not yaml_files: -# raise ValueError( -# "No YAML config files found in the specified directory. Are you using a .yml extension?" -# ) - -# print("Choose a YAML file:") -# for idx, file in enumerate(yaml_files): -# print(f"{idx + 1}. {file}") - -# chosen_file = None -# while chosen_file is None: -# try: -# choice = int(input("Enter the number of your choice: ")) -# if 1 <= choice <= len(yaml_files): -# chosen_file = yaml_files[choice - 1] -# else: -# print("Invalid choice. Please choose a number from the list.") -# except ValueError: -# print("Invalid input. Please enter a number.") - -# return chosen_file - - -# def save_latest_checkpoint_as_lora( -# config: Path = Path("configs/"), -# prepare_ds_only: bool = False, -# **kwargs, -# ): -# if Path(config).is_dir(): -# config = choose_config(config) - -# # load the config from the yaml file -# with open(config, "r") as f: -# cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader)) -# # if there are any options passed in the cli, if it is something that seems valid from the yaml, -# # then overwrite the value -# cfg_keys = dict(cfg).keys() -# for k in kwargs: -# if k in cfg_keys: -# # handle booleans -# if isinstance(cfg[k], bool): -# cfg[k] = bool(kwargs[k]) -# else: -# cfg[k] = kwargs[k] - -# # setup some derived config / hyperparams -# cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size -# cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) -# cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) -# assert cfg.local_rank == 0, "Run this with only one device!" - -# choose_device(cfg) -# cfg.ddp = False - -# if cfg.device == "mps": -# cfg.load_in_8bit = False -# cfg.tf32 = False -# if cfg.bf16: -# cfg.fp16 = True -# cfg.bf16 = False - -# # Load the model and tokenizer -# logging.info("loading model, tokenizer, and lora_config...") -# model, tokenizer, lora_config = load_model( -# cfg.base_model, -# cfg.base_model_config, -# cfg.model_type, -# cfg.tokenizer_type, -# cfg, -# adapter=cfg.adapter, -# inference=True, -# ) - -# model.config.use_cache = False - -# if torch.__version__ >= "2" and sys.platform != "win32": -# logging.info("Compiling torch model") -# model = torch.compile(model) - -# possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")] -# if len(possible_checkpoints) > 0: -# sorted_paths = sorted( -# possible_checkpoints, key=lambda path: int(path.split("-")[-1]) -# ) -# resume_from_checkpoint = sorted_paths[-1] -# else: -# raise FileNotFoundError("Checkpoints folder not found") - -# pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin") - -# assert os.path.exists(pytorch_bin_path), "Bin not found" - -# logging.info(f"Loading {pytorch_bin_path}") -# adapters_weights = torch.load(pytorch_bin_path, map_location="cpu") - -# # d = get_peft_model_state_dict(model) -# print(model.load_state_dict(adapters_weights)) -# # with open('b.log', "w") as f: -# # f.write(str(d.keys())) -# assert False - -# print((adapters_weights.keys())) -# with open("a.log", "w") as f: -# f.write(str(adapters_weights.keys())) -# assert False - -# logging.info("Setting peft model state dict") -# set_peft_model_state_dict(model, adapters_weights) - -# logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}") -# model.save_pretrained(cfg.output_dir) - - -# if __name__ == "__main__": -# fire.Fire(save_latest_checkpoint_as_lora)