Unify LLaMA 8-bit model loading and switch the DeepSpeed scheduler to WarmupDecayLR

This commit is contained in:
Wing Lian
2023-04-30 06:50:35 -04:00
parent 4dbef0941f
commit 9190ada23a
2 changed files with 11 additions and 16 deletions

View File

@@ -20,10 +20,12 @@
} }
}, },
"scheduler": { "scheduler": {
"type": "OneCycle", "type": "WarmupDecayLR",
"params": { "params": {
"cycle_min_lr": 1e-7, "warmup_min_lr": "auto",
"cycle_max_lr": 1e-4 "warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
} }
}, },
"zero_optimization": { "zero_optimization": {

View File

@@ -101,19 +101,12 @@ def load_model(
) )
load_in_8bit = False load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals(): elif is_llama_derived_model and "LlamaForCausalLM" in globals():
if not cfg.load_in_8bit: model = LlamaForCausalLM.from_pretrained(
model = LlamaForCausalLM.from_pretrained( base_model,
base_model, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
device_map=cfg.device_map, torch_dtype=torch_dtype,
) device_map=cfg.device_map,
else: )
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
elif model_type: elif model_type:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,