Merge pull request #356 from tmm1/load_model-args

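Simplify the `load_model` signature to `load_model(cfg, tokenizer)`; `base_model`, `base_model_config`, `model_type`, and `adapter` are now read from `cfg` inside the function.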
This commit is contained in:
Aman Gupta Karmani
2023-08-09 18:24:34 -07:00
committed by GitHub
2 changed files with 8 additions and 12 deletions

View File

@@ -255,14 +255,7 @@ def train(
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and peft_config...") LOG.info("loading model and peft_config...")
model, peft_config = load_model( model, peft_config = load_model(cfg, tokenizer)
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")

View File

@@ -78,12 +78,15 @@ def load_tokenizer(
def load_model( def load_model(
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora" cfg, tokenizer
): ): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model for a given configuration and tokenizer.
""" """
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
adapter = cfg.adapter
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit