From 526e5ee8b8d2b8edc65a1bd314965bffa101c132 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 8 Feb 2025 18:01:48 +0700 Subject: [PATCH] fix(config): missing config not being documented and fix model_ override (#2317) * fix(config): missing config not being documented and fix model_ space override * fix: delete redundant field --- docs/config.qmd | 4 ++++ src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 5 +++-- src/axolotl/utils/models.py | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index ecb571040..91744856f 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -46,6 +46,10 @@ overrides_of_model_config: type: # linear | dynamic factor: # float +# optional overrides to the kwargs passed to from_pretrained when loading the base model +overrides_of_model_kwargs: + # use_cache: False + # optional overrides to the bnb 4bit quantization configuration # https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig bnb_config_kwargs: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 028b7ea18..aa79c0f61 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -115,6 +115,9 @@ class RemappedParameters(BaseModel): overrides_of_model_config: Optional[Dict[str, Any]] = Field( default=None, alias="model_config" ) + overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field( + default=None, alias="model_kwargs" + ) type_of_model: Optional[str] = Field(default=None, alias="model_type") revision_of_model: Optional[str] = Field(default=None, alias="model_revision") @@ -426,8 +429,6 @@ class ModelInputConfig(BaseModel): ) trust_remote_code: Optional[bool] = None - model_kwargs: Optional[Dict[str, Any]] = None - @field_validator("trust_remote_code") @classmethod def hint_trust_remote_code(cls, trust_remote_code): diff --git 
a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d46564f42..be5b2782a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -357,8 +357,8 @@ class ModelLoader: # init model kwargs self.model_kwargs: Dict[str, Any] = {} - if cfg.model_kwargs: - for key, val in cfg.model_kwargs.items(): + if cfg.overrides_of_model_kwargs: + for key, val in cfg.overrides_of_model_kwargs.items(): self.model_kwargs[key] = val # init model