remove unnecessary local variable

This commit is contained in:
Aman Karmani
2023-08-13 01:58:39 +00:00
parent efb3b2c95e
commit 0c967279ce

View File

@@ -87,7 +87,6 @@ def load_model(
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
adapter = cfg.adapter
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
@@ -359,7 +358,7 @@ def load_model(
if hasattr(module, "weight"):
module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, adapter)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")