From 75e4fc2825681e5d2e7af9357eebc5859acb313e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 1 Nov 2023 01:45:36 -0400 Subject: [PATCH] wip more tp fixes --- src/axolotl/core/trainer_builder.py | 5 ++++- src/axolotl/utils/config.py | 4 ++++ src/axolotl/utils/models.py | 32 ++++++++++++++--------------- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 55a1764fc..195c5c426 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -371,7 +371,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return trainer_kwargs, trainer_cls def hook_post_create_trainer(self, trainer): - # TODO + if self.cfg.tensor_parallel: + trainer.model = trainer.accelerator.prepare_model( + trainer.model, device_placement=True + ) return trainer def get_callbacks(self): diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 81660ae65..893b4b723 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -369,6 +369,10 @@ def validate_config(cfg): "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." 
) + if cfg.tensor_parallel and cfg.gradient_checkpointing: + raise ValueError( + "TensorParallelPreTrainedModel does not support gradient checkpointing" + ) # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ccaf41f35..8f696c12a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -5,7 +5,6 @@ import math import os from typing import Optional, Tuple # noqa: F401 -import accelerate import bitsandbytes as bnb import tensor_parallel as tp import torch @@ -31,6 +30,7 @@ from transformers import ( # noqa: F401 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import is_distributed LOG = logging.getLogger("axolotl") @@ -328,19 +328,14 @@ def load_model( **model_kwargs, ) elif cfg.tensor_parallel: - config = AutoConfig.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( base_model, - trust_remote_code=cfg.trust_remote_code or False, - ) - with accelerate.init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config=config, - trust_remote_code=cfg.trust_remote_code or False, - ).half() - model = tp.TensorParallelPreTrainedModel( - model, - sharded=False, + torch_dtype=cfg.torch_dtype, + low_cpu_mem_usage=True, + offload_state_dict=True, ) + model = tp.tensor_parallel(model, distributed=is_distributed()) + model.hf_device_map = tp.infer_sharded_device_map(model) else: config = AutoConfig.from_pretrained( base_model, @@ -473,12 +468,17 @@ def load_model( load_file = torch.load try: with open( - hf_hub_download(base_model, "pytorch_model.bin.index.json"), "r" + hf_hub_download(base_model, "pytorch_model.bin.index.json"), + "r", + encoding="utf=8", ) as index_file: shard_filenames = set(json.load(index_file)["weight_map"].values()) - except: + except Exception as err: # pylint: 
disable=broad-exception-caught + LOG.warning(err) with open( - hf_hub_download(base_model, "model.safetensors.index.json"), "r" + hf_hub_download(base_model, "model.safetensors.index.json"), + "r", + encoding="utf=8", ) as index_file: shard_filenames = set(json.load(index_file)["weight_map"].values()) load_file = load_safetensors_file @@ -492,7 +492,7 @@ def load_model( tp.convert_state_dict( # <- tensor_parallel helper function. load_file( shard_path - ), # Creates a tensor_parallel checkpoint form a normal one + ), # Creates a tensor_parallel checkpoint from a normal one model.tensor_parallel_config, world_size=torch.cuda.device_count(), for_pretrained=True,