Author: Salman Mohammadi
Date: 2025-09-12 17:34:41 +00:00
parent 6874d32e0c
commit 850489405b


@@ -369,18 +369,25 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
    model.tie_weights()
    is_peft_model = isinstance(model, PeftModel)
    for name, module in model.named_children():
        if name == "experts":
            # torch.distributed.breakpoint()
            for expert in module.children():
                # torch.distributed.breakpoint()
                print(f"expert: {expert}")
                for lora_module in expert.children():
                    print(f"lora {lora_module}")
                    # torch.distributed.breakpoint()
                    cast_lora_module(lora_module)
                    _process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
    auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    log_bias_dtype_mismatch = False
    if auto_wrap_policy is not None:
        for module in get_module_children_bottom_up(model)[:-1]:
-           if is_peft_model and isinstance(module, LoraLayer):
+           if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
                # torch.distributed.breakpoint()
                cast_lora_module(module)
                module_log_bias_mismatch = _process_lora_module_for_fsdp(
                    module, fsdp2_kwargs
                )
                log_bias_dtype_mismatch |= module_log_bias_mismatch
                # torch.distributed.breakpoint()
            if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                fully_shard(module, **fsdp2_kwargs)
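
The hunk appears to add a pre-pass over the "experts" submodules before the generic bottom-up wrapping loop, and the new `not isinstance(module, FSDPModule)` condition keeps that later loop from re-processing modules that have already been sharded. Below is a minimal sketch (not part of this commit) of the guard pattern the check relies on: `fully_shard()` swaps the module's class to an `FSDPModule` subclass in place, so an `isinstance` check detects already-wrapped modules. It assumes torch >= 2.6, where `fully_shard` and `FSDPModule` are exported from `torch.distributed.fsdp` (earlier releases keep them under `torch.distributed._composable.fsdp`); `shard_once` is a hypothetical helper, not an accelerate API.

    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule, fully_shard

    def shard_once(module: nn.Module, **fsdp2_kwargs) -> None:
        # Hypothetical helper for illustration: skip modules that an earlier
        # pass already wrapped, since fully_shard() makes them FSDPModule instances.
        if isinstance(module, FSDPModule):
            return
        fully_shard(module, **fsdp2_kwargs)
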