working?
@@ -369,18 +369,25 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     model.tie_weights()

     is_peft_model = isinstance(model, PeftModel)

     for name, module in model.named_children():
         if name == "experts":
             # torch.distributed.breakpoint()
             for expert in module.children():
                 # torch.distributed.breakpoint()
                 print(f"expert: {expert}")
                 for lora_module in expert.children():
                     print(f"lora {lora_module}")
                     # torch.distributed.breakpoint()
                     cast_lora_module(lora_module)
                     _process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
     log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
-            if is_peft_model and isinstance(module, LoraLayer):
+            if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
                 # torch.distributed.breakpoint()
                 cast_lora_module(module)
                 module_log_bias_mismatch = _process_lora_module_for_fsdp(
                     module, fsdp2_kwargs
                 )
                 log_bias_dtype_mismatch |= module_log_bias_mismatch
                 # torch.distributed.breakpoint()
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)
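For context on the one-line change in the hunk (the `not isinstance(module, FSDPModule)` guard added to the `LoraLayer` branch): FSDP2's `fully_shard` swaps a wrapped module's class in place, so an `isinstance` check against `FSDPModule` is enough to make a later pass skip modules that were already processed. Below is a minimal, single-process sketch of that idempotency pattern; `fake_fully_shard` and the two-layer model are hypothetical stand-ins, not part of the commit, since the real `torch.distributed.fsdp.fully_shard` needs an initialized process group.

# Minimal sketch, not from the commit: why the isinstance(module, FSDPModule)
# guard makes a second wrapping pass a no-op. fake_fully_shard mimics one
# observable behavior of torch.distributed.fsdp.fully_shard: it swaps the
# module's class in place so later isinstance checks can detect it.
import torch

class FSDPModule:  # stand-in for torch.distributed.fsdp.FSDPModule
    pass

def fake_fully_shard(module: torch.nn.Module) -> None:
    # Mutate the class, the way composable FSDP2 does (e.g. Linear -> FSDPLinear).
    module.__class__ = type("FSDP" + type(module).__name__, (FSDPModule, type(module)), {})

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

for pass_idx in range(2):
    wrapped = 0
    for module in model.children():
        if not isinstance(module, FSDPModule):  # the guard added in this commit
            fake_fully_shard(module)
            wrapped += 1
    print(f"pass {pass_idx}: wrapped {wrapped}")  # pass 0: wrapped 2; pass 1: wrapped 0

Without the guard, a second call into the preparation path would hand already-sharded LoRA modules back to `cast_lora_module` and `_process_lora_module_for_fsdp`, which is what the added condition avoids.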