feat: add peft_trainable_token_indices (#3062)

* feat: add peft_trainable_token_indices

* feat: add warning compat with fix_untrained_tokens
This commit is contained in:
NanoCode012
2025-09-03 12:48:01 +07:00
committed by GitHub
parent 4cc6038d52
commit 53a0c1f39c
3 changed files with 41 additions and 1 deletions

View File

@@ -98,6 +98,8 @@ def load_lora(
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
lora_config = LoraConfig(
r=cfg.lora_r,

View File

@@ -947,7 +947,15 @@ class AxolotlInputConfig(
},
)
fix_untrained_tokens: int | list[int] | None = None
fix_untrained_tokens: int | list[int] | None = Field(
default=None,
json_schema_extra={
"description": (
"Token index or indices to adjust embedding weights to the mean of the other tokens. "
"This is useful when the model has untrained embeddings."
)
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
@@ -1006,6 +1014,26 @@ class AxolotlInputConfig(
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
@model_validator(mode="before")
@classmethod
def warn_peft_trainable_token_to_fix_untrained(cls, data):
    """Warn when a `fix_untrained_tokens` token is not trainable via PEFT.

    With `peft_trainable_token_indices`, only the listed embedding rows
    receive gradients. A token whose embedding is adjusted by
    `fix_untrained_tokens` but which is absent from the trainable list
    would therefore never actually be trained, so we emit a warning for
    each such token.

    Args:
        data: Raw (pre-validation) config mapping.

    Returns:
        The unmodified config mapping (this validator only warns).
    """
    if (
        peft_trainable_token_indices := data.get("peft_trainable_token_indices")
    ) and (fix_untrained_tokens := data.get("fix_untrained_tokens")):
        # `peft_trainable_token_indices` may be a dict mapping an embedding
        # layer name to its indices (see the LoraConfig schema). Membership
        # must be checked against the flattened index values — a plain
        # `in` test on the dict would check the layer-name keys instead.
        if isinstance(peft_trainable_token_indices, dict):
            trainable_ids = {
                idx
                for ids in peft_trainable_token_indices.values()
                for idx in ids
            }
        else:
            if isinstance(peft_trainable_token_indices, int):
                peft_trainable_token_indices = (peft_trainable_token_indices,)
            trainable_ids = set(peft_trainable_token_indices)
        # Normalize a scalar token id to an iterable for uniform handling.
        if isinstance(fix_untrained_tokens, int):
            fix_untrained_tokens = (fix_untrained_tokens,)
        for untrained_token_id in fix_untrained_tokens:
            if untrained_token_id not in trainable_ids:
                LOG.warning_once(
                    f"Token {untrained_token_id} is fixed via `fix_untrained_tokens`, "
                    "yet not in the `peft_trainable_token_indices` list. "
                    "Please add it, otherwise the token won't be trained on."
                )
    return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to validate GPU capabilities with the configured options"""

View File

@@ -90,6 +90,16 @@ class LoraConfig(BaseModel):
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
},
)
peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field(
default=None,
json_schema_extra={
"description": (
"A list of token indices to fine-tune on the `embed_tokens` layer.\n"
"Otherwise, a dict mapping an embedding layer name to its trainable token indices.\n"
"See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora"
)
},
)
qlora_sharded_model_loading: bool | None = Field(
default=False,