remove unnecessary local variable
This commit is contained in:
@@ -87,7 +87,6 @@ def load_model(
|
|||||||
base_model = cfg.base_model
|
base_model = cfg.base_model
|
||||||
base_model_config = cfg.base_model_config
|
base_model_config = cfg.base_model_config
|
||||||
model_type = cfg.model_type
|
model_type = cfg.model_type
|
||||||
adapter = cfg.adapter
|
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
@@ -359,7 +358,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch_dtype)
|
module.to(torch_dtype)
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|||||||
Reference in New Issue
Block a user