testing mpt triton

This commit is contained in:
Wing Lian
2023-05-09 20:57:40 -04:00
parent a27d594788
commit e2e68c3965
2 changed files with 10 additions and 2 deletions

View File

@@ -8,7 +8,7 @@ import transformers
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedModel, AutoConfig,
)
try:
from transformers import (
@@ -116,8 +116,14 @@ def load_model(
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
config.attn_config['attn_impl'] = 'triton'
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,

View File

@@ -2,7 +2,9 @@ import os
def setup_wandb_env_vars(cfg):
if cfg.wandb_project and len(cfg.wandb_project) > 0:
if cfg.wandb_mode and cfg.wandb_mode == "offline":
os.environ["WANDB_MODE"] = cfg.wandb_mode
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: