tweaks to data loading, 8 bit adam, accelerate and deepspeed

This commit is contained in:
Wing Lian
2023-04-22 16:25:23 -04:00
parent 4f2584f2dc
commit 097d367af6
4 changed files with 87 additions and 19 deletions

View File

@@ -101,12 +101,19 @@ def load_model(
)
load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
if not cfg.load_in_8bit:
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
elif model_type:
model = getattr(transformers, model_type).from_pretrained(
base_model,