diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ce85a47eb..a780dea01 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -109,7 +109,7 @@ def load_model( else: model = LlamaForCausalLM.from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) @@ -117,14 +117,14 @@ def load_model( elif model_type: model = getattr(transformers, model_type).from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) else: model = AutoModelForCausalLM.from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) @@ -135,7 +135,7 @@ def load_model( logging.exception(e) model = AutoModelForCausalLM.from_pretrained( base_model, - load_in_8bit=cfg.load_in_8bit, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, ) @@ -147,7 +147,7 @@ def load_model( else: tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model) except: - tokenizer = AutoTokenizer.from_pretrained(base_model) + tokenizer = AutoTokenizer.from_pretrained(base_model_config) logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")