diff --git a/scripts/finetune.py b/scripts/finetune.py
index ddf1992d6..4c40c5ef6 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -252,14 +252,7 @@ def train(

     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
-    model, peft_config = load_model(
-        cfg.base_model,
-        cfg.base_model_config,
-        cfg.model_type,
-        tokenizer,
-        cfg,
-        adapter=cfg.adapter,
-    )
+    model, peft_config = load_model(cfg, tokenizer)

     if "merge_lora" in kwargs and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7501878ba..83c648e46 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -77,12 +77,15 @@ def load_tokenizer(


 def load_model(
-    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
-):
-    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg, tokenizer
+):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
-    Load a model from a base model and a model type.
+    Load a model for a given configuration and tokenizer.
     """
+    base_model = cfg.base_model
+    base_model_config = cfg.base_model_config
+    model_type = cfg.model_type
+    adapter = cfg.adapter  # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
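
For reviewers: a minimal sketch of what call sites look like after this refactor. `load_model` now pulls `base_model`, `base_model_config`, `model_type`, and `adapter` from `cfg` itself, so callers only pass the config and the tokenizer. The config values below are illustrative assumptions, not part of this diff:

```python
# Usage sketch for the refactored load_model; field values are assumptions.
from transformers import AutoTokenizer

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model

cfg = DictDefault(
    {
        "base_model": "huggyllama/llama-7b",         # hypothetical example model
        "base_model_config": "huggyllama/llama-7b",  # typically the same repo as base_model
        "model_type": "LlamaForCausalLM",
        "adapter": "lora",
        "load_in_8bit": True,
    }
)

# DictDefault returns None for unset keys, so any optional cfg fields read
# inside load_model fall back to their defaults without a KeyError.
tokenizer = AutoTokenizer.from_pretrained(cfg.base_model_config)

# Previously: load_model(base_model, base_model_config, model_type, tokenizer, cfg, adapter=...)
model, peft_config = load_model(cfg, tokenizer)
```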