remove unnecessary local variable

This commit is contained in:
Aman Karmani
2023-08-13 01:58:39 +00:00
parent efb3b2c95e
commit 0c967279ce

View File

@@ -87,7 +87,6 @@ def load_model(
base_model = cfg.base_model base_model = cfg.base_model
base_model_config = cfg.base_model_config base_model_config = cfg.base_model_config
model_type = cfg.model_type model_type = cfg.model_type
adapter = cfg.adapter
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
@@ -359,7 +358,7 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch_dtype) module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")