fix: gemma3_text model loading vision config (#3354)

* fix: gemma3_text model loading vision config

* fix: improve defaults to use lora kernels
This commit is contained in:
NanoCode012
2026-01-13 21:49:23 +07:00
committed by GitHub
parent 258ce8d4fa
commit 359b7ad85e
6 changed files with 24 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-1b-it base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-270m-it base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too # Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true load_in_4bit: true
@@ -32,8 +33,8 @@ sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' lora_target_linear: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project: wandb_project:

View File

@@ -5,6 +5,7 @@ from typing import Type
import addict import addict
import torch import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -153,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config. necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -174,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels: if cfg.num_labels:
# num_labels is used to initialize classifier models # num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try: try:
model_config = AutoConfig.from_pretrained( model_config = config_cls.from_pretrained(
model_config_name, model_config_name,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**config_kwargs, **config_kwargs,

View File

@@ -25,7 +25,12 @@ class ModelInputConfig(BaseModel):
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model" "description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
}, },
) )
cls_model_config: str | None = None cls_model_config: str | None = Field(
default=None,
json_schema_extra={
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
},
)
tokenizer_config: str | None = Field( tokenizer_config: str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={