WIP: Rely on cfg.inference

This commit is contained in:
Angainor Development
2023-06-09 08:49:32 +02:00
committed by GitHub
parent 193c73bce0
commit 813cfa4c14

View File

@@ -80,8 +80,7 @@ def load_model(
model_type, model_type,
tokenizer, tokenizer,
cfg, cfg,
adapter="lora", adapter="lora"
inference=False,
): ):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
@@ -95,7 +94,7 @@ def load_model(
) )
if is_llama_derived_model and cfg.flash_attention: if is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False: if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention") logging.info("patching with flash attention")
@@ -402,7 +401,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
is_trainable=True, is_trainable=not cfg.inference,
device_map=cfg.device_map, device_map=cfg.device_map,
# torch_dtype=torch.float16, # torch_dtype=torch.float16,
) )