From 8e2a102ccabbc65e33d05a61385b2d9bc3c0dfcd Mon Sep 17 00:00:00 2001
From: bekk02
Date: Thu, 5 Mar 2026 06:59:32 -0800
Subject: [PATCH] Fix FSDP2 sharding and validate AO version for LR groups (#3403)

* Fix fsdp2 sharding. Fix validation of ao version for lr groups

* remove validation since axolotl requires ao>0.13.0 already

* Move fully_shard of entire module for lora_embedding_A/B out of loop

* chore: lint

---------

Co-authored-by: bekk02
Co-authored-by: Wing Lian
---
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 16 ++++++++++++----
 src/axolotl/utils/schemas/validation.py     | 17 -----------------
 2 files changed, 12 insertions(+), 21 deletions(-)

diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 4a8d9840f..dd3deb19a 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -252,12 +252,20 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
             fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
         if module.lora_B:
             fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
-        if module.lora_embedding_A:
-            fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
-        if module.lora_embedding_B:
-            fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
         if module.lora_magnitude_vector:
             fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
+
+    # lora_embedding_A/B are ParameterDicts of raw nn.Parameter (Tensors), not
+    # nn.Module, and fully_shard() only accepts nn.Module, so the embedding
+    # params cannot be sharded individually; shard the entire LoraLayer instead.
+    # fully_shard() can be applied hierarchically: it does not override groups
+    # already assigned by earlier fully_shard() calls. See https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html
+    if module.lora_embedding_A or module.lora_embedding_B:
+        from torch.distributed.fsdp import FSDPModule
+
+        if not isinstance(module, FSDPModule):
+            fully_shard(module, **fsdp2_kwargs)
+
     return log_bias_dtype_mismatch


diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 783017405..2ff57558f 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -986,23 +986,6 @@ class OptimizationValidationMixin:

         return self

-    @model_validator(mode="after")
-    def lr_groups_ao_optimizer(self):
-        if (
-            self.loraplus_lr_ratio is not None
-            or self.embedding_lr_scale is not None
-            or self.embedding_lr is not None
-            or self.lr_groups is not None
-        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
-            # TODO(wing): remove this once ao>0.12.0
-            # requires https://github.com/pytorch/ao/pull/2606 in an ao release
-            raise ValueError(
-                "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
-                "supported with ao low-bit optimizers until ao>0.12.0. "
-                "Please refer to https://github.com/pytorch/ao/pull/2606."
-            )
-        return self
-
     @model_validator(mode="before")
     @classmethod
     def check_tensor_parallel_size_update_ds_json(cls, data):
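
Note: the sketch below illustrates, outside the monkeypatch, the sharding strategy the fsdp2.py hunk adopts. It is a minimal, hypothetical example rather than axolotl code: ToyLoraLayer and shard_lora_module stand in for PEFT's LoraLayer and axolotl's _process_lora_module_for_fsdp, and it assumes a PyTorch build where fully_shard and FSDPModule are importable from torch.distributed.fsdp (the same import the patch uses). Actually running it requires an initialized process group and a device mesh, e.g. under torchrun.

# Hypothetical sketch of the patch's sharding strategy; not axolotl code.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, fully_shard


class ToyLoraLayer(nn.Module):
    """Toy stand-in for a PEFT LoraLayer wrapping an embedding."""

    def __init__(self, num_embeddings: int = 128, dim: int = 16, rank: int = 4):
        super().__init__()
        self.base_layer = nn.Embedding(num_embeddings, dim)
        # As in PEFT, the embedding adapters are ParameterDicts of raw
        # nn.Parameter, so they cannot be passed to fully_shard(), which
        # expects an nn.Module.
        self.lora_embedding_A = nn.ParameterDict(
            {"default": nn.Parameter(torch.zeros(rank, num_embeddings))}
        )
        self.lora_embedding_B = nn.ParameterDict(
            {"default": nn.Parameter(torch.zeros(dim, rank))}
        )


def shard_lora_module(module: ToyLoraLayer, **fsdp2_kwargs) -> None:
    """Apply the same sharding order the patch uses for one LoRA module."""
    # Proper sub-modules can be sharded directly.
    fully_shard(module.base_layer, **fsdp2_kwargs)

    # The raw adapter Parameters cannot be sharded individually, so shard the
    # whole module. fully_shard() composes hierarchically: parameters already
    # grouped by the call above stay in their own group.
    if (module.lora_embedding_A or module.lora_embedding_B) and not isinstance(
        module, FSDPModule
    ):
        fully_shard(module, **fsdp2_kwargs)

The isinstance(module, FSDPModule) guard mirrors the patch: since the embedding-adapter handling previously sat inside the per-adapter loop, the guard keeps a module that has already been wrapped from being passed to fully_shard() a second time.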