various bugfixes

This commit is contained in:
Wing Lian
2023-04-19 17:04:34 -04:00
parent 2624bc2f11
commit 94f5e415a3
6 changed files with 63 additions and 10 deletions

View File

@@ -102,13 +102,20 @@ def load_model(
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
-else:
+elif model_type:
model = getattr(transformers, model_type).from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
+else:
+model = AutoModelForCausalLM.from_pretrained(
+base_model,
+load_in_8bit=cfg.load_in_8bit,
+torch_dtype=torch_dtype,
+device_map=cfg.device_map,
+)
except Exception as e:
logging.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -148,7 +155,7 @@ def load_model(
model, lora_config = load_adapter(model, cfg, adapter)
-if cfg.ddp:
+if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
if cfg.load_4bit: