fix: gemma3_text model loading vision config (#3354)
* fix: gemma3-text model loading vision config * fix: improve defaults to use lora kernels
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
|
||||
|
||||
# Must be set; otherwise transformers tries to load the vision config too
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -32,8 +33,8 @@ sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Type
|
||||
|
||||
import addict
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -153,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
||||
This function determines the appropriate model config source, loads it, applies any
|
||||
necessary overrides, and validates it for compatibility with the `axolotl` config.
|
||||
|
||||
If `cfg.cls_model_config` is set, a custom config class from transformers will be
|
||||
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
@@ -174,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
||||
if cfg.num_labels:
|
||||
# num_labels is used to initialize classifier models
|
||||
config_kwargs["num_labels"] = cfg.num_labels
|
||||
|
||||
config_cls = AutoConfig
|
||||
if cfg.cls_model_config:
|
||||
config_cls = getattr(transformers, cfg.cls_model_config)
|
||||
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config = config_cls.from_pretrained(
|
||||
model_config_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**config_kwargs,
|
||||
|
||||
@@ -25,7 +25,12 @@ class ModelInputConfig(BaseModel):
|
||||
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
|
||||
},
|
||||
)
|
||||
cls_model_config: str | None = None
|
||||
cls_model_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
|
||||
},
|
||||
)
|
||||
tokenizer_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
Reference in New Issue
Block a user