From 094fc2c6e6a44cff58a5cefe798a821de009fa74 Mon Sep 17 00:00:00 2001
From: Aman Gupta Karmani
Date: Sat, 12 Aug 2023 21:32:07 -0700
Subject: [PATCH] try to detect accelerate and only use device_map=None in
 that case (#373)

---
 src/axolotl/utils/config.py | 6 ++++++
 src/axolotl/utils/models.py | 4 ++++
 2 files changed, 10 insertions(+)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 9873687aa..b7bbab668 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -30,6 +30,12 @@ def choose_device(cfg):
         else:
             cfg.device_map = {"": cfg.device}
 
+    # in `accelerate launch`, we need to not pass through any device map and let
+    # accelerate figure out which parts of the model to put on which gpu
+    accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
+    if accelerate_vars:
+        cfg.device_map = None
+
 
 def normalize_config(cfg):
     # setup some derived config / hyperparams
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5ed140b07..1f36c50db 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -235,6 +235,7 @@ def load_model(
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -269,6 +270,7 @@ def load_model(
     elif model_type and not cfg.trust_remote_code:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -299,6 +301,7 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
+                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
@@ -312,6 +315,7 @@ def load_model(
             LOG.exception(err)
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
+                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,