8bit and deepspeed changes

This commit is contained in:
Wing Lian
2023-04-30 06:50:35 -04:00
parent 4dbef0941f
commit 9190ada23a
2 changed files with 11 additions and 16 deletions

View File

@@ -101,19 +101,12 @@ def load_model(
)
load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
if not cfg.load_in_8bit:
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
elif model_type:
model = getattr(transformers, model_type).from_pretrained(
base_model,