8-bit loading and DeepSpeed changes: switch scheduler from OneCycle to WarmupDecayLR; unify the LLaMA from_pretrained call

This commit is contained in:
Wing Lian
2023-04-30 06:50:35 -04:00
parent 4dbef0941f
commit 9190ada23a
2 changed files with 11 additions and 16 deletions

View File

@@ -20,10 +20,12 @@
}
},
"scheduler": {
"type": "OneCycle",
"type": "WarmupDecayLR",
"params": {
"cycle_min_lr": 1e-7,
"cycle_max_lr": 1e-4
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {

View File

@@ -101,19 +101,12 @@ def load_model(
)
load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
if not cfg.load_in_8bit:
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
elif model_type:
model = getattr(transformers, model_type).from_pretrained(
base_model,