diff --git a/scripts/finetune.py b/scripts/finetune.py
index 8a458890c..f74ee9b23 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -165,7 +165,7 @@ def train(
     cfg_keys = cfg.keys()
     for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or cfg.strict is False:
+        if k in cfg_keys or not cfg.strict:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -205,8 +205,8 @@ def train(
     logging.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)

-    if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
+    if (
+        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
@@ -234,7 +234,6 @@ def train(
         tokenizer,
         cfg,
         adapter=cfg.adapter,
-        inference=("inference" in kwargs),
     )

     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -247,7 +246,7 @@ def train(
             model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return

-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
         inf_kwargs: Dict[str, Any] = {}
         if "prompter" in kwargs:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index b79f116fa..506c98e00 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -77,15 +77,9 @@ def load_tokenizer(


 def load_model(
-    base_model,
-    base_model_config,
-    model_type,
-    tokenizer,
-    cfg,
-    adapter="lora",
-    inference=False,
+    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
@@ -98,7 +92,7 @@ def load_model(
     )

     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn

             logging.info("patching with flash attention")
@@ -439,6 +433,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
+            is_trainable=not cfg.inference,
             device_map=cfg.device_map,
             # torch_dtype=torch.float16,
         )
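
Note (commentary, not part of the patch): the load_lora hunk relies on peft's is_trainable flag; below is a minimal sketch of that behaviour, assuming peft is installed and an adapter checkpoint exists at lora_model_dir. The helper name attach_lora and its arguments are hypothetical and used only for illustration.

    from peft import PeftModel

    def attach_lora(model, lora_model_dir, inference):
        # peft's is_trainable defaults to False, which loads the adapter with its
        # weights frozen (what we want when cfg.inference is set); passing
        # is_trainable=not inference keeps the adapter trainable when resuming
        # fine-tuning from an existing LoRA checkpoint.
        return PeftModel.from_pretrained(
            model,
            lora_model_dir,
            is_trainable=not inference,
        )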