diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 2f998d144..d84368bc0 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -1,6 +1,7 @@
 base_model: google/gemma-3-1b-it
 model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
 
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
 
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
 adapter: qlora
 lora_r: 32
 lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_linear: true
 
 sequence_len: 2048
diff --git a/examples/gemma3/gemma-3-270m-qlora.yml b/examples/gemma3/gemma-3-270m-qlora.yml
index 0c60c4a01..14ea2aaba 100644
--- a/examples/gemma3/gemma-3-270m-qlora.yml
+++ b/examples/gemma3/gemma-3-270m-qlora.yml
@@ -1,6 +1,7 @@
 base_model: google/gemma-3-270m-it
 model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
 
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
 
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
 adapter: qlora
 lora_r: 32
 lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_linear: true
 
 sequence_len: 2048
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index 959521149..7d44f3c9b 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
 
 # Need to set else transformers tries to load vision too
 model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
 
 load_in_4bit: true
 
@@ -32,8 +33,8 @@ sample_packing: true
 
 lora_r: 32
 lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_dropout: 0
+lora_target_linear: true
 
 wandb_project:
 wandb_entity:
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index b42b6b492..a12e84bee 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
 
 lora_r: 32
 lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 
 wandb_project:
diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py
index b1902c9b5..187784b93 100644
--- a/src/axolotl/loaders/utils.py
+++ b/src/axolotl/loaders/utils.py
@@ -5,6 +5,7 @@ from typing import Type
 
 import addict
 import torch
+import transformers
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
 from axolotl.utils.dict import DictDefault
@@ -153,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
     This function determines the appropriate model config source, loads it, applies
     any necessary overrides, and validates it for compatibility with the `axolotl` config.
 
+    If `cfg.cls_model_config` is set, a custom config class from transformers will be
+    used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
+
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
 
@@ -174,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
     if cfg.num_labels:
         # num_labels is used to initialize classifier models
         config_kwargs["num_labels"] = cfg.num_labels
+
+    config_cls = AutoConfig
+    if cfg.cls_model_config:
+        config_cls = getattr(transformers, cfg.cls_model_config)
+
     try:
-        model_config = AutoConfig.from_pretrained(
+        model_config = config_cls.from_pretrained(
             model_config_name,
             trust_remote_code=trust_remote_code,
             **config_kwargs,
diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py
index 04312eedd..0931608a6 100644
--- a/src/axolotl/utils/schemas/model.py
+++ b/src/axolotl/utils/schemas/model.py
@@ -25,7 +25,12 @@ class ModelInputConfig(BaseModel):
             "description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
         },
     )
-    cls_model_config: str | None = None
+    cls_model_config: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
+        },
+    )
     tokenizer_config: str | None = Field(
         default=None,
         json_schema_extra={
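
For reference, a minimal standalone sketch of the config-class resolution pattern this diff wires into `load_model_config`. The `resolve_config_cls` helper name and the example model id are illustrative only, not part of the change, and it assumes a transformers release that ships the named class (e.g. Gemma3TextConfig):

# Sketch (not part of the PR) of how a `cls_model_config` value maps to a
# transformers config class by attribute lookup on the transformers module.
import transformers
from transformers import AutoConfig


def resolve_config_cls(cls_model_config: str | None):
    # Look the class up by name on the transformers module, falling back to
    # AutoConfig when no explicit class is requested (the pre-existing behavior).
    if cls_model_config:
        return getattr(transformers, cls_model_config)
    return AutoConfig


# With `cls_model_config: Gemma3TextConfig` in the YAML examples above, the loader
# ends up calling Gemma3TextConfig.from_pretrained(...) instead of AutoConfig.
config_cls = resolve_config_cls("Gemma3TextConfig")
model_config = config_cls.from_pretrained("google/gemma-3-1b-it")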