From e2e68c396521790d4693d52be538ecaf8e907e2f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 May 2023 20:57:40 -0400 Subject: [PATCH] testing mpt triton --- src/axolotl/utils/models.py | 8 +++++++- src/axolotl/utils/wandb.py | 4 +++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8c80b2621..76782be4f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -8,7 +8,7 @@ import transformers from transformers import ( AutoModelForCausalLM, AutoTokenizer, - PreTrainedModel, + PreTrainedModel, AutoConfig, ) try: from transformers import ( @@ -116,8 +116,14 @@ def load_model( trust_remote_code=True if cfg.trust_remote_code is True else False, ) else: + config = AutoConfig.from_pretrained( + base_model, + trust_remote_code=True if cfg.trust_remote_code is True else False, + ) + config.attn_config['attn_impl'] = 'triton' model = AutoModelForCausalLM.from_pretrained( base_model, + config=config, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py index 1e805c6c6..992bb1a5f 100644 --- a/src/axolotl/utils/wandb.py +++ b/src/axolotl/utils/wandb.py @@ -2,7 +2,9 @@ import os def setup_wandb_env_vars(cfg): - if cfg.wandb_project and len(cfg.wandb_project) > 0: + if cfg.wandb_mode and cfg.wandb_mode == "offline": + os.environ["WANDB_MODE"] = cfg.wandb_mode + elif cfg.wandb_project and len(cfg.wandb_project) > 0: os.environ["WANDB_PROJECT"] = cfg.wandb_project cfg.use_wandb = True if cfg.wandb_watch and len(cfg.wandb_watch) > 0: