try to detect accelerate and only use device_map=None in that case (#373)
This commit is contained in:
committed by
GitHub
parent
2dafa730ef
commit
094fc2c6e6
@@ -30,6 +30,12 @@ def choose_device(cfg):
|
|||||||
else:
|
else:
|
||||||
cfg.device_map = {"": cfg.device}
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
# in `accelerate launch`, we need to not pass through any device map and let
|
||||||
|
# accelerate figure out which parts of the model to put on which gpu
|
||||||
|
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
|
||||||
|
if accelerate_vars:
|
||||||
|
cfg.device_map = None
|
||||||
|
|
||||||
|
|
||||||
def normalize_config(cfg):
|
def normalize_config(cfg):
|
||||||
# setup some derived config / hyperparams
|
# setup some derived config / hyperparams
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ def load_model(
|
|||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -269,6 +270,7 @@ def load_model(
|
|||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -299,6 +301,7 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -312,6 +315,7 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
|||||||
Reference in New Issue
Block a user