testing mpt triton

This commit is contained in:
Wing Lian
2023-05-09 20:57:40 -04:00
parent a27d594788
commit e2e68c3965
2 changed files with 10 additions and 2 deletions

View File

@@ -8,7 +8,7 @@ import transformers
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
PreTrainedModel, PreTrainedModel, AutoConfig,
) )
try: try:
from transformers import ( from transformers import (
@@ -116,8 +116,14 @@ def load_model(
trust_remote_code=True if cfg.trust_remote_code is True else False, trust_remote_code=True if cfg.trust_remote_code is True else False,
) )
else: else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
config.attn_config['attn_impl'] = 'triton'
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map, device_map=cfg.device_map,

View File

@@ -2,7 +2,9 @@ import os
def setup_wandb_env_vars(cfg): def setup_wandb_env_vars(cfg):
if cfg.wandb_project and len(cfg.wandb_project) > 0: if cfg.wandb_mode and cfg.wandb_mode == "offline":
os.environ["WANDB_MODE"] = cfg.wandb_mode
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0: