WIP: tensor parallelism (tp)
This commit is contained in:
@@ -4,7 +4,6 @@ import math
|
||||
import os
|
||||
from typing import Optional, Tuple # noqa: F401
|
||||
|
||||
import accelerate
|
||||
import bitsandbytes as bnb
|
||||
import tensor_parallel as tp
|
||||
import torch
|
||||
@@ -328,14 +327,21 @@ def load_model(
|
||||
base_model,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config=config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
).half()
|
||||
model = tp.TensorParallelPreTrainedModel(
|
||||
model,
|
||||
)
|
||||
# with accelerate.init_empty_weights():
|
||||
# model = AutoModelForCausalLM.from_config(
|
||||
# config=config,
|
||||
# trust_remote_code=cfg.trust_remote_code or False,
|
||||
# ).half()
|
||||
# model = tp.TensorParallelPreTrainedModel(
|
||||
# model,
|
||||
# sharded=False,
|
||||
# )
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
).half()
|
||||
model = tp.tensor_parallel(model, sharded=False)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
@@ -386,15 +392,18 @@ def load_model(
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
try:
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
except NotImplementedError:
|
||||
LOG.warning("`resize_token_embeddings` not implemented on model")
|
||||
|
||||
if (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
@@ -497,7 +506,10 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
try:
|
||||
model.enable_input_require_grads()
|
||||
except NotImplementedError:
|
||||
LOG.warning("enable_input_require_grads not implemented on model")
|
||||
if adapter == "qlora" and cfg.tensor_parallel:
|
||||
return load_tp_qlora(model)
|
||||
if adapter in ["lora", "qlora"]:
|
||||
|
||||
Reference in New Issue
Block a user