fix: gemma3_text model loading vision config (#3354)

* fix: gemma3-text mode loading vision config

* fix: improve defaults to use lora kernels
This commit is contained in:
NanoCode012
2026-01-13 21:49:23 +07:00
committed by GitHub
parent 258ce8d4fa
commit 359b7ad85e
6 changed files with 24 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_linear: true
sequence_len: 2048

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_linear: true
sequence_len: 2048

View File

@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -32,8 +33,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_dropout: 0
lora_target_linear: true
wandb_project:
wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:

View File

@@ -5,6 +5,7 @@ from typing import Type
import addict
import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
@@ -153,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -174,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels:
# num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try:
model_config = AutoConfig.from_pretrained(
model_config = config_cls.from_pretrained(
model_config_name,
trust_remote_code=trust_remote_code,
**config_kwargs,

View File

@@ -25,7 +25,12 @@ class ModelInputConfig(BaseModel):
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
},
)
cls_model_config: str | None = None
cls_model_config: str | None = Field(
default=None,
json_schema_extra={
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
},
)
tokenizer_config: str | None = Field(
default=None,
json_schema_extra={