feat: add peft_trainable_token_indices (#3062)
* feat: add peft_trainable_token_indices
* feat: add warning compat with fix_untrained_tokens
@@ -98,6 +98,8 @@ def load_lora(
         lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
     if cfg.peft_layer_replication:
         lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
+    if cfg.peft_trainable_token_indices:
+        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
 
     lora_config = LoraConfig(
         r=cfg.lora_r,
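For context, a minimal sketch (not part of the diff) of what the forwarded kwarg does at the PEFT level, per the PEFT docs linked further below. The model name and token ids are hypothetical placeholders.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Hypothetical model and token ids; 32000/32001 stand in for newly added
# special tokens whose embedding rows should be trained.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    # Only these embedding rows are trained alongside the LoRA weights,
    # rather than making the whole embed_tokens matrix trainable.
    trainable_token_indices=[32000, 32001],
)
peft_model = get_peft_model(model, lora_config)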
@@ -947,7 +947,15 @@ class AxolotlInputConfig(
         },
     )
 
-    fix_untrained_tokens: int | list[int] | None = None
+    fix_untrained_tokens: int | list[int] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "Token index or indices to adjust embedding weights to the mean of the other tokens. "
+                "This is useful when the model has untrained embeddings."
+            )
+        },
+    )
 
     # INTERNALS - document for now, generally not set externally
     is_preprocess: bool | None = None
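The new description documents the existing behavior of `fix_untrained_tokens`. A hedged sketch of the underlying idea, not axolotl's actual implementation (which lives elsewhere in the codebase): reset the embedding rows for the given token ids to the mean of all other rows.

import torch

def mean_init_rows(embedding: torch.nn.Embedding, token_ids: list[int]) -> None:
    # Sketch only: set each untrained row to the mean of the trained rows.
    with torch.no_grad():
        keep = torch.ones(embedding.num_embeddings, dtype=torch.bool)
        keep[token_ids] = False  # exclude the untrained rows from the mean
        embedding.weight[token_ids] = embedding.weight[keep].mean(dim=0)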
@@ -1006,6 +1014,26 @@ class AxolotlInputConfig(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
 
+    @model_validator(mode="before")
+    @classmethod
+    def warn_peft_trainable_token_to_fix_untrained(cls, data):
+        if (
+            peft_trainable_token_indices := data.get("peft_trainable_token_indices")
+        ) and (fix_untrained_tokens := data.get("fix_untrained_tokens")):
+            if isinstance(fix_untrained_tokens, int):
+                fix_untrained_tokens = (fix_untrained_tokens,)
+
+            if isinstance(peft_trainable_token_indices, int):
+                peft_trainable_token_indices = (peft_trainable_token_indices,)
+
+            for untrained_token_id in fix_untrained_tokens:
+                if untrained_token_id not in peft_trainable_token_indices:
+                    LOG.warning_once(
+                        f"Token {untrained_token_id} is fixed via `fix_untrained_tokens` but is not in the "
+                        "`peft_trainable_token_indices` list. Please add it, otherwise the token won't be trained."
+                    )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to validate GPU capabilities with the configured options"""
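To illustrate when the validator above fires, a standalone sketch of the same membership check, with hypothetical token ids:

# 32001 is fixed via fix_untrained_tokens but missing from
# peft_trainable_token_indices, so the validator would warn for it.
data = {
    "fix_untrained_tokens": [32000, 32001],
    "peft_trainable_token_indices": [32000],
}
missing = [
    t for t in data["fix_untrained_tokens"]
    if t not in data["peft_trainable_token_indices"]
]
print(missing)  # [32001]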
@@ -90,6 +90,16 @@ class LoraConfig(BaseModel):
             "description": "How to initialize LoRA weights. Defaults to True, which is the MS original implementation."
         },
     )
+    peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "A list of token indices to fine-tune on the `embed_tokens` layer.\n"
+                "Otherwise, a dict mapping an embedding layer name to its trainable token indices.\n"
+                "See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora"
+            )
+        },
+    )
 
     qlora_sharded_model_loading: bool | None = Field(
         default=False,
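The new field accepts either of the two shapes described above. A sketch with hypothetical token ids (the layer names in the dict form are illustrative and model-dependent):

# List form: the indices apply to the model's embed_tokens layer.
peft_trainable_token_indices = [32000, 32001]

# Dict form: map each embedding layer name to its own trainable indices,
# e.g. when input and output embeddings are separate modules.
peft_trainable_token_indices = {
    "embed_tokens": [32000, 32001],
    "lm_head": [32000, 32001],
}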