add support for trust_remote_code for mpt models

This commit is contained in:
Wing Lian
2023-05-08 12:07:27 -04:00
parent 709be5af81
commit a125693122
3 changed files with 68 additions and 0 deletions

View File

@@ -113,6 +113,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
model = AutoModelForCausalLM.from_pretrained(
@@ -120,6 +121,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except Exception as e:
logging.error(
@@ -131,6 +133,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
if not tokenizer: