testing mpt triton
This commit is contained in:
@@ -8,7 +8,7 @@ import transformers
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PreTrainedModel,
|
PreTrainedModel, AutoConfig,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -116,8 +116,14 @@ def load_model(
|
|||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
|
)
|
||||||
|
config.attn_config['attn_impl'] = 'triton'
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
config=config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
def setup_wandb_env_vars(cfg):
|
def setup_wandb_env_vars(cfg):
|
||||||
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
||||||
|
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
||||||
|
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user