feat: add peft_trainable_token_indices (#3062)
* feat: add peft_trainable_token_indices * feat: add warning compat with fix_untrained_tokens
This commit is contained in:
@@ -98,6 +98,8 @@ def load_lora(
|
|||||||
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
|
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
|
||||||
if cfg.peft_layer_replication:
|
if cfg.peft_layer_replication:
|
||||||
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
|
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
|
||||||
|
if cfg.peft_trainable_token_indices:
|
||||||
|
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=cfg.lora_r,
|
r=cfg.lora_r,
|
||||||
|
|||||||
@@ -947,7 +947,15 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
fix_untrained_tokens: int | list[int] | None = None
|
fix_untrained_tokens: int | list[int] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"Token index or indices to adjust embedding weights to the mean of the other tokens. "
|
||||||
|
"This is useful when the model has untrained embeddings."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
is_preprocess: bool | None = None
|
is_preprocess: bool | None = None
|
||||||
@@ -1006,6 +1014,26 @@ class AxolotlInputConfig(
|
|||||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def warn_peft_trainable_token_to_fix_untrained(cls, data):
|
||||||
|
if (
|
||||||
|
peft_trainable_token_indices := data.get("peft_trainable_token_indices")
|
||||||
|
) and (fix_untrained_tokens := data.get("fix_untrained_tokens")):
|
||||||
|
if isinstance(fix_untrained_tokens, int):
|
||||||
|
fix_untrained_tokens = (fix_untrained_tokens,)
|
||||||
|
|
||||||
|
if isinstance(peft_trainable_token_indices, int):
|
||||||
|
peft_trainable_token_indices = (peft_trainable_token_indices,)
|
||||||
|
|
||||||
|
for untrained_token_id in fix_untrained_tokens:
|
||||||
|
if untrained_token_id not in peft_trainable_token_indices:
|
||||||
|
LOG.warning_once(
|
||||||
|
f"Token {untrained_token_id} is fixed via `fix_untrained_tokens`, yet not in `peft_trainable_token_indices: ` list. "
|
||||||
|
"Please add it, otherwise the token won't be trained on."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate GPU capabilities with the configured options"""
|
"""wrapper to valdiate GPU capabilities with the configured options"""
|
||||||
|
|||||||
@@ -90,6 +90,16 @@ class LoraConfig(BaseModel):
|
|||||||
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
|
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"A list of token indices to fine-tune on the `embed_tokens` layer.\n"
|
||||||
|
"Otherwise, a dict mapping an embedding layer name to its trainable token indices.\n"
|
||||||
|
"See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
qlora_sharded_model_loading: bool | None = Field(
|
qlora_sharded_model_loading: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user