diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 867e6901c..989b34aee 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -98,6 +98,8 @@ def load_lora(
         lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
     if cfg.peft_layer_replication:
         lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
+    if cfg.peft_trainable_token_indices:
+        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
 
     lora_config = LoraConfig(
         r=cfg.lora_r,
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d43c346cd..1d2ddf4ae 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -947,7 +947,15 @@ class AxolotlInputConfig(
         },
     )
 
-    fix_untrained_tokens: int | list[int] | None = None
+    fix_untrained_tokens: int | list[int] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "Token index or indices whose embedding weights are adjusted to the mean of the other tokens. "
+                "Useful when the model has untrained embeddings."
+            )
+        },
+    )
 
     # INTERNALS - document for now, generally not set externally
     is_preprocess: bool | None = None
@@ -1006,6 +1014,35 @@
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
 
+    @model_validator(mode="before")
+    @classmethod
+    def warn_peft_trainable_token_to_fix_untrained(cls, data):
+        if (
+            peft_trainable_token_indices := data.get("peft_trainable_token_indices")
+        ) and (fix_untrained_tokens := data.get("fix_untrained_tokens")):
+            if isinstance(fix_untrained_tokens, int):
+                fix_untrained_tokens = (fix_untrained_tokens,)
+
+            if isinstance(peft_trainable_token_indices, int):
+                peft_trainable_token_indices = (peft_trainable_token_indices,)
+
+            if isinstance(peft_trainable_token_indices, dict):
+                # The dict form maps embedding layer names to lists of token
+                # indices; flatten the values for the membership check below.
+                peft_trainable_token_indices = [
+                    index
+                    for indices in peft_trainable_token_indices.values()
+                    for index in indices
+                ]
+
+            for untrained_token_id in fix_untrained_tokens:
+                if untrained_token_id not in peft_trainable_token_indices:
+                    LOG.warning_once(
+                        f"Token {untrained_token_id} is fixed via `fix_untrained_tokens` but is missing from "
+                        "`peft_trainable_token_indices`. Please add it there, otherwise the token will not be trained."
+                    )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate GPU capabilities with the configured options"""
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index de29521cb..af22913fd 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -90,6 +90,16 @@ class LoraConfig(BaseModel):
             "description": "How to initialize LoRA weights. Default to True which is MS original implementation."
         },
     )
+    peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "A list of token indices to fine-tune in the `embed_tokens` layer.\n"
+                "Alternatively, a dict mapping an embedding layer name to its trainable token indices.\n"
+                "See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora"
+            )
+        },
+    )
     qlora_sharded_model_loading: bool | None = Field(
         default=False,
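For context, a minimal sketch of what the pass-through in `load_lora` amounts to on the PEFT side (assumptions: a PEFT version that supports `trainable_token_indices`, per the v0.17.0 docs linked in the field description; the `r`, `target_modules`, and token indices below are placeholder values, not taken from this diff):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],
    # Flat list form: the indices apply to the model's input embedding layer.
    trainable_token_indices=[128011, 128012],
    # Dict form: map embedding layer names to their trainable token indices,
    # e.g. trainable_token_indices={"embed_tokens": [128011, 128012]},
)
```

Combining this with `fix_untrained_tokens` on the same indices, as the new validator nudges users to do, first resets those embedding rows to the mean of the other embeddings and then trains only those rows alongside the LoRA adapter.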