wip more tp fixes

This commit is contained in:
Wing Lian
2023-11-01 01:45:36 -04:00
parent e13c2fd6b1
commit 75e4fc2825
3 changed files with 24 additions and 17 deletions

View File

@@ -371,7 +371,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# Post-creation hook: receives a trainer and returns that same trainer object.
# NOTE(review): presumably invoked by the builder right after the trainer is
# constructed — confirm against the caller, which is not visible here.
# TODO
# When cfg.tensor_parallel is truthy, replace trainer.model with whatever
# trainer.accelerator.prepare_model() returns, asking it to handle device
# placement (device_placement=True). Otherwise the trainer is left untouched.
if self.cfg.tensor_parallel:
trainer.model = trainer.accelerator.prepare_model(
trainer.model, device_placement=True
)
return trainer
def get_callbacks(self):

View File

@@ -369,6 +369,10 @@ def validate_config(cfg):
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.tensor_parallel and cfg.gradient_checkpointing:
raise ValueError(
"TensorParallelPreTrainedModel does not support gradient checkpointing"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -5,7 +5,6 @@ import math
import os
from typing import Optional, Tuple # noqa: F401
import accelerate
import bitsandbytes as bnb
import tensor_parallel as tp
import torch
@@ -31,6 +30,7 @@ from transformers import ( # noqa: F401
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_distributed
LOG = logging.getLogger("axolotl")
@@ -328,19 +328,14 @@ def load_model(
**model_kwargs,
)
elif cfg.tensor_parallel:
config = AutoConfig.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
with accelerate.init_empty_weights():
model = AutoModelForCausalLM.from_config(
config=config,
trust_remote_code=cfg.trust_remote_code or False,
).half()
model = tp.TensorParallelPreTrainedModel(
model,
sharded=False,
torch_dtype=cfg.torch_dtype,
low_cpu_mem_usage=True,
offload_state_dict=True,
)
model = tp.tensor_parallel(model, distributed=is_distributed())
model.hf_device_map = tp.infer_sharded_device_map(model)
else:
config = AutoConfig.from_pretrained(
base_model,
@@ -473,12 +468,17 @@ def load_model(
load_file = torch.load
try:
with open(
hf_hub_download(base_model, "pytorch_model.bin.index.json"), "r"
hf_hub_download(base_model, "pytorch_model.bin.index.json"),
"r",
encoding="utf-8",
) as index_file:
shard_filenames = set(json.load(index_file)["weight_map"].values())
except:
except Exception as err: # pylint: disable=broad-exception-caught
LOG.warning(err)
with open(
hf_hub_download(base_model, "model.safetensors.index.json"), "r"
hf_hub_download(base_model, "model.safetensors.index.json"),
"r",
encoding="utf-8",
) as index_file:
shard_filenames = set(json.load(index_file)["weight_map"].values())
load_file = load_safetensors_file
@@ -492,7 +492,7 @@ def load_model(
tp.convert_state_dict( # <- tensor_parallel helper function.
load_file(
shard_path
), # Creates a tensor_parallel checkpoint from a normal one
), # Creates a tensor_parallel checkpoint from a normal one
model.tensor_parallel_config,
world_size=torch.cuda.device_count(),
for_pretrained=True,