WIP: Rely on cfg.inference
This commit is contained in:
committed by
GitHub
parent
193c73bce0
commit
813cfa4c14
@@ -80,8 +80,7 @@ def load_model(
|
|||||||
model_type,
|
model_type,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
adapter="lora",
|
adapter="lora"
|
||||||
inference=False,
|
|
||||||
):
|
):
|
||||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
"""
|
"""
|
||||||
@@ -95,7 +94,7 @@ def load_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_llama_derived_model and cfg.flash_attention:
|
if is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
|
||||||
logging.info("patching with flash attention")
|
logging.info("patching with flash attention")
|
||||||
@@ -402,7 +401,7 @@ def load_lora(model, cfg):
|
|||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
is_trainable=True,
|
is_trainable=not cfg.inference,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
# torch_dtype=torch.float16,
|
# torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user