diff --git a/scripts/finetune.py b/scripts/finetune.py index 47623a518..53d6d8557 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -255,14 +255,7 @@ def train( # Load the model and tokenizer LOG.info("loading model and peft_config...") - model, peft_config = load_model( - cfg.base_model, - cfg.base_model_config, - cfg.model_type, - tokenizer, - cfg, - adapter=cfg.adapter, - ) + model, peft_config = load_model(cfg, tokenizer) if "merge_lora" in kwargs and cfg.adapter is not None: LOG.info("running merge of LoRA with base model") diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 31e211953..9224d0f4d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -78,12 +78,15 @@ def load_tokenizer( def load_model( - base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora" -): - # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + cfg, tokenizer +): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]] """ - Load a model from a base model and a model type. + Load a model for a given configuration and tokenizer. """ + base_model = cfg.base_model + base_model_config = cfg.base_model_config + model_type = cfg.model_type + adapter = cfg.adapter # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit