Merge pull request #356 from tmm1/load_model-args
simplify `load_model` signature
This commit is contained in:
@@ -255,14 +255,7 @@ def train(
|
|||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and peft_config...")
|
LOG.info("loading model and peft_config...")
|
||||||
model, peft_config = load_model(
|
model, peft_config = load_model(cfg, tokenizer)
|
||||||
cfg.base_model,
|
|
||||||
cfg.base_model_config,
|
|
||||||
cfg.model_type,
|
|
||||||
tokenizer,
|
|
||||||
cfg,
|
|
||||||
adapter=cfg.adapter,
|
|
||||||
)
|
|
||||||
|
|
||||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
|
|||||||
@@ -78,12 +78,15 @@ def load_tokenizer(
|
|||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
cfg, tokenizer
|
||||||
):
|
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
|
||||||
"""
|
"""
|
||||||
Load a model from a base model and a model type.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
|
base_model = cfg.base_model
|
||||||
|
base_model_config = cfg.base_model_config
|
||||||
|
model_type = cfg.model_type
|
||||||
|
adapter = cfg.adapter
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
|
|||||||
Reference in New Issue
Block a user