Fix FSDP2 sharding and validate AO version for LR groups (#3403)
* Fix fsdp2 sharding. Fix validation of ao version for lr groups

* remove validation since axolotl requires ao>0.13.0 already

* Move fully_shard of entire module for lora_embedding_A/B out of loop

* chore: lint

---------

Co-authored-by: bekk02 <ID+bekk02@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -252,12 +252,20 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
             fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
         if module.lora_B:
             fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
-        if module.lora_embedding_A:
-            fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
-        if module.lora_embedding_B:
-            fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
         if module.lora_magnitude_vector:
             fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)

+    # lora_embedding_A/B are ParameterDicts containing nn.Parameter (Tensors),
+    # not nn.Module. fully_shard() only accepts nn.Module, so we cannot shard
+    # individual embedding Parameters. Instead, shard the entire LoraLayer
+    # module. fully_shard() can be used hierarchically because it does not
+    # override groups already assigned by fully_shard(), so modules
+    # where fully_shard() was already called are not affected [see https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html]
+    if module.lora_embedding_A or module.lora_embedding_B:
+        from torch.distributed.fsdp import FSDPModule
+
+        if not isinstance(module, FSDPModule):
+            fully_shard(module, **fsdp2_kwargs)

     return log_bias_dtype_mismatch
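The constraint described in the comment above can be sketched in isolation. The following is an illustrative helper, not axolotl's code: the function name shard_lora_layer is made up, and it assumes a recent PyTorch where both fully_shard and FSDPModule are exported from torch.distributed.fsdp.

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, fully_shard


def shard_lora_layer(module: nn.Module, active_adapter: str, **fsdp2_kwargs) -> None:
    """Illustrative sketch of the sharding pattern used in the hunk above."""
    # Per-adapter sub-modules (e.g. the lora_A/lora_B Linear layers) are plain
    # nn.Modules, so they can be sharded individually.
    for name in ("lora_A", "lora_B", "lora_magnitude_vector"):
        adapters = getattr(module, name, None)
        if adapters:
            fully_shard(adapters[active_adapter], **fsdp2_kwargs)

    # lora_embedding_A/B are ParameterDicts of raw nn.Parameter, which
    # fully_shard() cannot wrap, so the whole LoRA layer is sharded instead.
    # fully_shard() composes hierarchically and leaves already-sharded
    # children in their own groups.
    if getattr(module, "lora_embedding_A", None) or getattr(module, "lora_embedding_B", None):
        if not isinstance(module, FSDPModule):
            fully_shard(module, **fsdp2_kwargs)

The isinstance(module, FSDPModule) guard mirrors the diff: it keeps the call idempotent if the parent layer was already sharded elsewhere.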
@@ -986,23 +986,6 @@ class OptimizationValidationMixin:
         return self

-    @model_validator(mode="after")
-    def lr_groups_ao_optimizer(self):
-        if (
-            self.loraplus_lr_ratio is not None
-            or self.embedding_lr_scale is not None
-            or self.embedding_lr is not None
-            or self.lr_groups is not None
-        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
-            # TODO(wing): remove this once ao>0.12.0
-            # requires https://github.com/pytorch/ao/pull/2606 in an ao release
-            raise ValueError(
-                "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
-                "supported with ao low-bit optimizers until ao>0.12.0. "
-                "Please refer to https://github.com/pytorch/ao/pull/2606."
-            )
-        return self
-
     @model_validator(mode="before")
     @classmethod
     def check_tensor_parallel_size_update_ds_json(cls, data):
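For reference only, not part of this PR: had the guard been kept instead of deleted, it could have been gated on the installed torchao version. The model below is a simplified, hypothetical stand-in (OptimizerSettings, lr_groups as a bare list of dicts), not axolotl's actual config schema.

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version
from pydantic import BaseModel, model_validator


class OptimizerSettings(BaseModel):
    optimizer: str = "adamw_torch"
    lr_groups: list[dict] | None = None

    @model_validator(mode="after")
    def check_lr_groups_with_low_bit_ao(self):
        # Only relevant when lr groups are combined with an ao low-bit optimizer.
        if self.lr_groups and self.optimizer in ("adamw_torch_8bit", "adamw_torch_4bit"):
            try:
                ao_version = Version(version("torchao"))
            except PackageNotFoundError:
                return self
            if ao_version <= Version("0.12.0"):
                raise ValueError(
                    "lr groups with ao low-bit optimizers need torchao > 0.12.0 "
                    "(see https://github.com/pytorch/ao/pull/2606)."
                )
        return self

Dropping the check entirely, as this commit does, is the simpler choice once the minimum supported torchao already satisfies the requirement.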