Merge pull request #159 from AngainorDev/patch-1

Fix training over an existing LoRA adapter
This commit is contained in:
NanoCode012
2023-06-12 20:27:11 +09:00
committed by GitHub
2 changed files with 8 additions and 14 deletions

View File

@@ -165,7 +165,7 @@ def train(
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
for k, _ in kwargs.items(): for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already # if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False: if k in cfg_keys or not cfg.strict:
# handle booleans # handle booleans
if isinstance(cfg[k], bool): if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k]) cfg[k] = bool(kwargs[k])
@@ -205,8 +205,8 @@ def train(
logging.info(f"loading tokenizer... {tokenizer_config}") logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg) tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in( if (
["inference", "shard", "merge_lora"], kwargs check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these ): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
@@ -234,7 +234,6 @@ def train(
tokenizer, tokenizer,
cfg, cfg,
adapter=cfg.adapter, adapter=cfg.adapter,
inference=("inference" in kwargs),
) )
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -247,7 +246,7 @@ def train(
model.save_pretrained(str(Path(cfg.output_dir) / "merged")) model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if "inference" in kwargs: if cfg.inference:
logging.info("calling do_inference function") logging.info("calling do_inference function")
inf_kwargs: Dict[str, Any] = {} inf_kwargs: Dict[str, Any] = {}
if "prompter" in kwargs: if "prompter" in kwargs:

View File

@@ -77,15 +77,9 @@ def load_tokenizer(
def load_model( def load_model(
base_model, base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
base_model_config,
model_type,
tokenizer,
cfg,
adapter="lora",
inference=False,
): ):
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model from a base model and a model type.
""" """
@@ -98,7 +92,7 @@ def load_model(
) )
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False: if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention") logging.info("patching with flash attention")
@@ -439,6 +433,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
is_trainable=not cfg.inference,
device_map=cfg.device_map, device_map=cfg.device_map,
# torch_dtype=torch.float16, # torch_dtype=torch.float16,
) )