Fix FSDP2 sharding and validate AO version for LR groups (#3403)

* Fix fsdp2 sharding. Fix validation of ao version for lr groups

* remove validation since axolotl requires ao>0.13.0 already

* Move fully_shard of entire module for lora_embedding_A/B out of loop

* chore: lint

---------

Co-authored-by: bekk02 <ID+bekk02@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
bekk02
2026-03-05 06:59:32 -08:00
committed by GitHub
parent 753906cfc7
commit 8e2a102cca
2 changed files with 12 additions and 21 deletions

View File

@@ -252,12 +252,20 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
# lora_embedding_A/B are ParameterDicts containing nn.Parameter (Tensors),
# not nn.Module. fully_shard() only accepts nn.Module, so we cannot shard
# individual embedding Parameters. Instead, shard the entire LoraLayer module.
# fully_shard() can be applied hierarchically because it does not override
# groups already assigned by an earlier fully_shard() call, so sub-modules
# that were already sharded are not affected
# [see https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html]
if module.lora_embedding_A or module.lora_embedding_B:
from torch.distributed.fsdp import FSDPModule
if not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
return log_bias_dtype_mismatch

View File

@@ -986,23 +986,6 @@ class OptimizationValidationMixin:
return self
@model_validator(mode="after")
def lr_groups_ao_optimizer(self):
    """Reject LR-group options when combined with ao low-bit optimizers.

    LR groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`,
    `lr_groups`) are incompatible with the ao 8-bit/4-bit AdamW optimizers
    until the referenced upstream ao fix lands in a release.

    Raises:
        ValueError: if any LR-group option is set together with an ao
            low-bit optimizer.
    """
    lr_group_options = (
        self.loraplus_lr_ratio,
        self.embedding_lr_scale,
        self.embedding_lr,
        self.lr_groups,
    )
    uses_lr_groups = any(option is not None for option in lr_group_options)
    low_bit_optimizer = self.optimizer.value in ("adamw_torch_8bit", "adamw_torch_4bit")
    if uses_lr_groups and low_bit_optimizer:
        # TODO(wing): remove this once ao>0.12.0
        # requires https://github.com/pytorch/ao/pull/2606 in an ao release
        raise ValueError(
            "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
            "supported with ao low-bit optimizers until ao>0.12.0. "
            "Please refer to https://github.com/pytorch/ao/pull/2606."
        )
    return self
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):