From 813cfa4c14f990c53ed42e9decd84b3e41a91102 Mon Sep 17 00:00:00 2001 From: Angainor Development <54739135+AngainorDev@users.noreply.github.com> Date: Fri, 9 Jun 2023 08:49:32 +0200 Subject: [PATCH] WIP: Rely on cfg.inference --- src/axolotl/utils/models.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b5d5124cb..c3f988e52 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -80,8 +80,7 @@ def load_model( model_type, tokenizer, cfg, - adapter="lora", - inference=False, + adapter="lora" ): # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] """ @@ -95,7 +94,7 @@ def load_model( ) if is_llama_derived_model and cfg.flash_attention: - if cfg.device not in ["mps", "cpu"] and inference is False: + if cfg.device not in ["mps", "cpu"] and cfg.inference is False: from axolotl.flash_attn import replace_llama_attn_with_flash_attn logging.info("patching with flash attention") @@ -402,7 +401,7 @@ def load_lora(model, cfg): model = PeftModel.from_pretrained( model, cfg.lora_model_dir, - is_trainable=True, + is_trainable=not cfg.inference, device_map=cfg.device_map, # torch_dtype=torch.float16, )