working?
This commit is contained in:
@@ -369,18 +369,25 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
is_peft_model = isinstance(model, PeftModel)
|
is_peft_model = isinstance(model, PeftModel)
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if name == "experts":
|
||||||
|
# torch.distributed.breakpoint()
|
||||||
|
for expert in module.children():
|
||||||
|
# torch.distributed.breakpoint()
|
||||||
|
print(f"expert: {expert}")
|
||||||
|
for lora_module in expert.children():
|
||||||
|
print(f"lora {lora_module}")
|
||||||
|
# torch.distributed.breakpoint()
|
||||||
|
cast_lora_module(lora_module)
|
||||||
|
_process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
|
||||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
for module in get_module_children_bottom_up(model)[:-1]:
|
for module in get_module_children_bottom_up(model)[:-1]:
|
||||||
if is_peft_model and isinstance(module, LoraLayer):
|
if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
|
||||||
# torch.distributed.breakpoint()
|
# torch.distributed.breakpoint()
|
||||||
cast_lora_module(module)
|
cast_lora_module(module)
|
||||||
module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
# torch.distributed.breakpoint()
|
||||||
module, fsdp2_kwargs
|
|
||||||
)
|
|
||||||
log_bias_dtype_mismatch |= module_log_bias_mismatch
|
|
||||||
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
||||||
fully_shard(module, **fsdp2_kwargs)
|
fully_shard(module, **fsdp2_kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user