fix model parallel (#816)
This commit is contained in:
@@ -442,14 +442,7 @@ def load_model(
|
|||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|
||||||
if (
|
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||||
torch.cuda.device_count() > 1
|
|
||||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
|
||||||
and (cfg.load_in_4bit)
|
|
||||||
):
|
|
||||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
|
||||||
# so let's only set it for the 4bit, see
|
|
||||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
|
||||||
setattr(model, "is_parallelizable", True)
|
setattr(model, "is_parallelizable", True)
|
||||||
setattr(model, "model_parallel", True)
|
setattr(model, "model_parallel", True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user