From 964d858da007ff5a06e91aa8bae9a22e3ff8c043 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 2 Nov 2023 21:34:22 -0400 Subject: [PATCH] fix model parallel (#816) --- src/axolotl/utils/models.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index cc83840ba..8848e9503 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -442,14 +442,7 @@ def load_model( if cfg.ddp and not load_in_8bit: model.to(f"cuda:{cfg.local_rank}") - if ( - torch.cuda.device_count() > 1 - and int(os.getenv("WORLD_SIZE", "1")) > 1 - and (cfg.load_in_4bit) - ): - # llama is PROBABLY model parallelizable, but the default isn't that it is - # so let's only set it for the 4bit, see - # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133 + if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: setattr(model, "is_parallelizable", True) setattr(model, "model_parallel", True)