simplify linear layer locator
This commit is contained in:
committed by
Aman Gupta Karmani
parent
98bf76e236
commit
267b7b24e5
@@ -464,12 +464,8 @@ def load_llama_adapter(model, cfg):
|
|||||||
return model, peft_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(bits, model):
|
def find_all_linear_names(model):
|
||||||
cls = (
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
bnb.nn.Linear4bit
|
|
||||||
if bits == 4
|
|
||||||
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
|
|
||||||
)
|
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, cls):
|
if isinstance(module, cls):
|
||||||
@@ -490,13 +486,7 @@ def load_lora(model, cfg):
|
|||||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||||
|
|
||||||
if cfg.lora_target_linear:
|
if cfg.lora_target_linear:
|
||||||
bits = None
|
linear_names = find_all_linear_names(model)
|
||||||
if cfg.load_in_4bit:
|
|
||||||
bits = 4
|
|
||||||
elif cfg.load_in_8bit:
|
|
||||||
bits = 8
|
|
||||||
|
|
||||||
linear_names = find_all_linear_names(bits, model)
|
|
||||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user