diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3057d5da5..ee91d81ae 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -341,7 +341,12 @@ def load_model( config=config, trust_remote_code=cfg.trust_remote_code or False, low_cpu_mem_usage=True, - ).half() + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, + load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, + torch_dtype=cfg.torch_dtype, + device_map={"": "cpu"}, + **model_kwargs, + ) model = tp.tensor_parallel(model, sharded=False) else: config = AutoConfig.from_pretrained(