From 323da791eb3f1c3c005058ba56cc272835339eaa Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 15 Apr 2026 09:27:03 -0400
Subject: [PATCH] bump transformers to 5.5.4 and trl to latest 1.1.0 (#3603)

* bump transformers to 5.5.4 and trl to latest 1.1.0

* more upgrades

* update peft too

* adapt lora_merge to peft 0.19 layer config API

PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv layer
constructors and moves use_rslora, use_dora, fan_in_fan_out, lora_dropout,
and lora_bias into that config. Build the config per branch in
_build_peft_layer_and_get_delta so the merge utility works with the
upgraded peft.

* allow lora_dropout on mixed attention+MoE configs under peft 0.19

PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D einsum
can't factor dropout out of lora_B(lora_A(dropout(x))).

Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
---
 requirements.txt                    |  12 +--
 src/axolotl/cli/utils/lora_merge.py |  36 +++++++--
 src/axolotl/loaders/adapter.py      | 109 ++++++++++++++++++++++++++++
 3 files changed, 144 insertions(+), 13 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index fe4b67436..bb3fc8daa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,15 +10,15 @@
 liger-kernel==0.7.0
 packaging==26.0
 huggingface_hub>=1.1.7
-peft>=0.18.1
+peft>=0.19.0,<0.20.0
 tokenizers>=0.22.1
-transformers==5.5.3
+transformers==5.5.4
 accelerate==1.13.0
-datasets==4.5.0
+datasets>=4.8.4,<4.9.0
 deepspeed>=0.18.6,<0.19.0
-trl==0.29.0
-hf_xet==1.3.2
-kernels==0.12.2
+trl==1.1.0
+hf_xet==1.4.3
+kernels==0.13.0
 fla-core==0.4.1
 flash-linear-attention==0.4.1

diff --git a/src/axolotl/cli/utils/lora_merge.py b/src/axolotl/cli/utils/lora_merge.py
index b7c436aed..a07395587 100644
--- a/src/axolotl/cli/utils/lora_merge.py
+++ b/src/axolotl/cli/utils/lora_merge.py
@@ -315,15 +315,27 @@ def _build_peft_layer_and_get_delta(
             "weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
         )
 
+        # ParamWrapper rejects dropout/fan_in_fan_out/lora_bias/use_dora, so
+        # build a minimal config that pins those options to their inert
+        # defaults.
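+        # (use_rslora must still be forwarded, though: rsLoRA scales the
+        # update by lora_alpha / sqrt(r) instead of lora_alpha / r, so
+        # dropping it would change the merged delta.)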
+        pw_config = LoraConfig(
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=0.0,
+            fan_in_fan_out=False,
+            use_rslora=use_rslora,
+            use_dora=False,
+            lora_bias=False,
+        )
+
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", UserWarning)
             layer = ParamWrapper(
                 fake,
                 adapter_name=adapter_name,
                 parameter_name="weight",
+                config=pw_config,
                 r=r,
                 lora_alpha=lora_alpha,
-                use_rslora=use_rslora,
             )
         layer.lora_A[adapter_name].weight.data = lora_a
         layer.lora_B[adapter_name].weight.data = lora_b
@@ -375,14 +387,19 @@ def _build_peft_layer_and_get_delta(
         )
         base_layer.weight.data = base_tensor.clone()
 
-        layer = PeftConvCls(
-            base_layer,
-            adapter_name=adapter_name,
+        conv_config = LoraConfig(
             r=r_total,
             lora_alpha=lora_alpha,
             use_rslora=use_rslora,
             use_dora=use_dora,
         )
+        layer = PeftConvCls(
+            base_layer,
+            adapter_name=adapter_name,
+            config=conv_config,
+            r=r_total,
+            lora_alpha=lora_alpha,
+        )
 
         layer.lora_A[adapter_name].weight.data = lora_a
         layer.lora_B[adapter_name].weight.data = lora_b
@@ -410,15 +427,20 @@ def _build_peft_layer_and_get_delta(
             or lora_config_dict.get("lora_fan_in_fan_out", False)
         )
 
-        layer = LoraLinear(
-            base_layer,
-            adapter_name=adapter_name,
+        linear_config = LoraConfig(
             r=r_total,
             lora_alpha=lora_alpha,
             fan_in_fan_out=fan_in_fan_out,
             use_rslora=use_rslora,
             use_dora=use_dora,
         )
+        layer = LoraLinear(
+            base_layer,
+            adapter_name=adapter_name,
+            config=linear_config,
+            r=r_total,
+            lora_alpha=lora_alpha,
+        )
 
         layer.lora_A[adapter_name].weight.data = lora_a
         layer.lora_B[adapter_name].weight.data = lora_b

diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 6d0bd0fe1..3d662c0bb 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -124,6 +124,101 @@ def _patch_peft_clippable_linear():
     LoraModel._axolotl_clippable_patched = True
 
 
+def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
+    """Check whether PEFT will auto-populate target_parameters for this model.
+
+    PEFT 0.19's ``convert_peft_config_for_transformers`` rewrites old MoE
+    ``target_modules`` (e.g. ``w1``/``w2``/``w3`` on Mixtral) into
+    ``target_parameters`` (``gate_up_proj``/``down_proj``) because
+    transformers v5 fused those expert linears into 3D ``nn.Parameter``
+    tensors. PEFT wraps the resulting 3D params with ``ParamWrapper``,
+    which rejects ``lora_dropout != 0``. This probe runs the conversion on
+    a copy of the config so we can detect the situation before
+    ``get_peft_model`` raises.
+    """
+    if getattr(lora_config, "target_parameters", None):
+        return False
+
+    try:
+        from peft.utils.transformers_weight_conversion import (
+            convert_peft_config_for_transformers,
+            get_model_conversion_mapping,
+        )
+    except ImportError:
+        return False
+
+    import copy
+
+    probe_cfg = copy.deepcopy(lora_config)
+    try:
+        convert_peft_config_for_transformers(
+            probe_cfg,
+            model=model,
+            conversions=get_model_conversion_mapping(model),
+        )
+    except Exception:  # pylint: disable=broad-except
+        return False
+
+    return bool(getattr(probe_cfg, "target_parameters", None))
+
+
+def _patch_peft_param_wrapper_dropout():
+    """Let PEFT's ``ParamWrapper`` silently accept ``lora_dropout != 0``.
+
+    ``ParamWrapper`` wraps 3D expert ``nn.Parameter`` tensors and rejects
+    non-zero dropout because dropout can't be factored out of
+    ``lora_B(lora_A(dropout(x)))`` when the inner op is an expert-indexed
+    matmul.
+    For mixed configs (attention + MoE experts) this is too aggressive:
+    the non-expert ``Linear`` LoRA layers *can* apply dropout, and that's
+    usually what the user intended. We pass a copy of the ``LoraConfig``
+    with ``lora_dropout=0`` only to ``ParamWrapper.__init__`` so it builds
+    with ``nn.Identity`` for its internal dropout slot while every other
+    layer type still receives the real dropout value.
+    """
+    from peft.tuners.lora.layer import ParamWrapper
+
+    if getattr(ParamWrapper, "_axolotl_dropout_patched", False):
+        return
+
+    _orig_init = ParamWrapper.__init__
+
+    def _patched_init(
+        self,
+        base_layer,
+        adapter_name,
+        parameter_name,
+        config,
+        *args,
+        **kwargs,
+    ):
+        if getattr(config, "lora_dropout", 0):
+            import copy as _copy
+
+            # Shallow-copy the shared config so sibling Linear LoRA layers
+            # still see the real dropout value.
+            config = _copy.copy(config)
+            config.lora_dropout = 0.0
+        return _orig_init(
+            self,
+            base_layer,
+            adapter_name,
+            parameter_name,
+            config,
+            *args,
+            **kwargs,
+        )
+
+    ParamWrapper.__init__ = _patched_init
+    ParamWrapper._axolotl_dropout_patched = True
+
+
 def load_lora(
     model: PreTrainedModel,
     cfg: DictDefault,
@@ -191,6 +286,20 @@ def load_lora(
     if config_only:
         return None, lora_config
 
+    if getattr(
+        lora_config, "lora_dropout", 0
+    ) and _peft_will_auto_convert_target_params(model, lora_config):
+        LOG.warning(
+            "lora_dropout=%s requested but PEFT will wrap this model's fused "
+            "MoE expert parameters with ParamWrapper, which cannot apply "
+            "dropout (the 3D einsum can't factor dropout out of "
+            "lora_B(lora_A(dropout(x)))). Dropout will still be applied to "
+            "non-expert LoRA layers (e.g. attention), and expert LoRA layers "
+            "will use nn.Identity for the dropout slot.",
+            lora_config.lora_dropout,
+        )
+        _patch_peft_param_wrapper_dropout()
+
     rank = int(os.environ.get("LOCAL_RANK", 0))
     if (
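
Note (not part of the patch): a minimal sketch of how the split dropout
behavior can be verified after get_peft_model. It assumes PEFT's usual
layout, where each LoRA layer stores its per-adapter dropout module in a
``lora_dropout`` ModuleDict; the helper name is ours:

    from peft.tuners.lora.layer import ParamWrapper

    def summarize_lora_dropout(peft_model, adapter_name="default"):
        # Walk every module carrying a per-adapter lora_dropout ModuleDict
        # and report whether that adapter got a real Dropout or the
        # Identity stub installed by the monkeypatch.
        for name, module in peft_model.named_modules():
            drops = getattr(module, "lora_dropout", None)
            if drops is None or adapter_name not in drops:
                continue
            kind = "expert" if isinstance(module, ParamWrapper) else "linear"
            print(f"{name} [{kind}] -> {type(drops[adapter_name]).__name__}")

With the patch active, the attention q/k/v/o layers should report Dropout
while the ParamWrapper-wrapped expert parameters report Identity.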