diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 68306a689..c5429e05e 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -369,18 +369,22 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         model.tie_weights()

     is_peft_model = isinstance(model, PeftModel)
-
+    # Pre-process LoRA layers nested inside MoE expert containers: cast and
+    # shard them up front, since the bottom-up pass below now skips LoRA
+    # layers that are already FSDP-wrapped.
+    for name, module in model.named_children():
+        if name == "experts":
+            for expert in module.children():
+                for lora_module in expert.children():
+                    cast_lora_module(lora_module)
+                    _process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)

     log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
-            if is_peft_model and isinstance(module, LoraLayer):
+            if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
                 # torch.distributed.breakpoint()
                 cast_lora_module(module)
-                module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                    module, fsdp2_kwargs
-                )
-                log_bias_dtype_mismatch |= module_log_bias_mismatch
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)
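
For context, a minimal standalone sketch of the traversal the new loop performs. ToyExpert and ToyMoEBlock are illustrative stand-ins, not part of the patch or of axolotl; in the real code the innermost modules are PEFT LoraLayer instances, and the print stands in for cast_lora_module / _process_lora_module_for_fsdp:

import torch.nn as nn

# Illustrative stand-ins for an MoE block whose expert container is literally
# named "experts", matching the name check in the patched fsdp2_prepare_model.
class ToyExpert(nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_A = nn.Linear(8, 4, bias=False)
        self.lora_B = nn.Linear(4, 8, bias=False)

class ToyMoEBlock(nn.Module):
    def __init__(self, num_experts=2):
        super().__init__()
        self.router = nn.Linear(8, num_experts)
        self.experts = nn.ModuleList([ToyExpert() for _ in range(num_experts)])

model = ToyMoEBlock()

# Same traversal shape as the added loop: find the child named "experts",
# then walk two levels down to the modules that would be cast and sharded
# before the generic bottom-up wrapping pass.
for name, module in model.named_children():
    if name == "experts":
        for expert in module.children():
            for lora_module in expert.children():
                print(f"would cast and shard: {lora_module}")

Running this prints the two Linear stand-ins per expert, confirming the loop reaches exactly the nesting depth (experts container, then expert, then LoRA submodule) that the patch pre-processes before the generic bottom-up pass.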