diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index b5d5124cb..c3f988e52 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -80,8 +80,7 @@ def load_model(
     model_type,
     tokenizer,
     cfg,
-    adapter="lora",
-    inference=False,
+    adapter="lora"
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
@@ -95,7 +94,7 @@ def load_model(
     )
 
     if is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
@@ -402,7 +401,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
-            is_trainable=True,
+            is_trainable=not cfg.inference,
            device_map=cfg.device_map,
            # torch_dtype=torch.float16,
        )
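
Not part of the patch above, but a minimal sketch of how the new `cfg.inference` flag is expected to flow through the two call sites it touches. `SimpleNamespace` stands in here for axolotl's `DictDefault` config object, and the field names simply mirror the hunks above:

```python
# Minimal sketch (not from this patch): cfg.inference driving both changed
# call sites. SimpleNamespace is a stand-in for axolotl's DictDefault.
from types import SimpleNamespace

cfg = SimpleNamespace(
    inference=True,        # True -> loading the model for generation
    device="cuda:0",
    device_map="auto",
    flash_attention=True,
    lora_model_dir="./lora-out",
)

# load_model(): the flash-attention monkey-patch is only applied while
# training, i.e. when cfg.inference is False on a CUDA device.
should_patch = cfg.device not in ["mps", "cpu"] and cfg.inference is False
print(should_patch)  # False -> no patch during inference

# load_lora(): PeftModel.from_pretrained(..., is_trainable=not cfg.inference)
# now loads the adapter weights frozen when generating.
print(not cfg.inference)  # False -> is_trainable=False
```

The net effect is that callers no longer thread an `inference` keyword through `load_model`'s signature; the flag lives on the config and also fixes `load_lora`, which previously hard-coded `is_trainable=True` even for inference runs.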