wip more tp fixes
This commit is contained in:
@@ -371,7 +371,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return trainer_kwargs, trainer_cls
|
return trainer_kwargs, trainer_cls
|
||||||
|
|
||||||
def hook_post_create_trainer(self, trainer):
|
def hook_post_create_trainer(self, trainer):
|
||||||
# TODO
|
if self.cfg.tensor_parallel:
|
||||||
|
trainer.model = trainer.accelerator.prepare_model(
|
||||||
|
trainer.model, device_placement=True
|
||||||
|
)
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
|
|||||||
@@ -369,6 +369,10 @@ def validate_config(cfg):
|
|||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.tensor_parallel and cfg.gradient_checkpointing:
|
||||||
|
raise ValueError(
|
||||||
|
"TensorParallelPreTrainedModel does not support gradient checkpointing"
|
||||||
|
)
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import math
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple # noqa: F401
|
from typing import Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
import accelerate
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import tensor_parallel as tp
|
import tensor_parallel as tp
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,7 @@ from transformers import ( # noqa: F401
|
|||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import is_distributed
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -328,19 +328,14 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif cfg.tensor_parallel:
|
elif cfg.tensor_parallel:
|
||||||
config = AutoConfig.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
torch_dtype=cfg.torch_dtype,
|
||||||
)
|
low_cpu_mem_usage=True,
|
||||||
with accelerate.init_empty_weights():
|
offload_state_dict=True,
|
||||||
model = AutoModelForCausalLM.from_config(
|
|
||||||
config=config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
).half()
|
|
||||||
model = tp.TensorParallelPreTrainedModel(
|
|
||||||
model,
|
|
||||||
sharded=False,
|
|
||||||
)
|
)
|
||||||
|
model = tp.tensor_parallel(model, distributed=is_distributed())
|
||||||
|
model.hf_device_map = tp.infer_sharded_device_map(model)
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -473,12 +468,17 @@ def load_model(
|
|||||||
load_file = torch.load
|
load_file = torch.load
|
||||||
try:
|
try:
|
||||||
with open(
|
with open(
|
||||||
hf_hub_download(base_model, "pytorch_model.bin.index.json"), "r"
|
hf_hub_download(base_model, "pytorch_model.bin.index.json"),
|
||||||
|
"r",
|
||||||
|
encoding="utf=8",
|
||||||
) as index_file:
|
) as index_file:
|
||||||
shard_filenames = set(json.load(index_file)["weight_map"].values())
|
shard_filenames = set(json.load(index_file)["weight_map"].values())
|
||||||
except:
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
|
LOG.warning(err)
|
||||||
with open(
|
with open(
|
||||||
hf_hub_download(base_model, "model.safetensors.index.json"), "r"
|
hf_hub_download(base_model, "model.safetensors.index.json"),
|
||||||
|
"r",
|
||||||
|
encoding="utf=8",
|
||||||
) as index_file:
|
) as index_file:
|
||||||
shard_filenames = set(json.load(index_file)["weight_map"].values())
|
shard_filenames = set(json.load(index_file)["weight_map"].values())
|
||||||
load_file = load_safetensors_file
|
load_file = load_safetensors_file
|
||||||
@@ -492,7 +492,7 @@ def load_model(
|
|||||||
tp.convert_state_dict( # <- tensor_parallel helper function.
|
tp.convert_state_dict( # <- tensor_parallel helper function.
|
||||||
load_file(
|
load_file(
|
||||||
shard_path
|
shard_path
|
||||||
), # Creates a tensor_parallel checkpoint form a normal one
|
), # Creates a tensor_parallel checkpoint form a normal one
|
||||||
model.tensor_parallel_config,
|
model.tensor_parallel_config,
|
||||||
world_size=torch.cuda.device_count(),
|
world_size=torch.cuda.device_count(),
|
||||||
for_pretrained=True,
|
for_pretrained=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user