Compare commits

..

2 Commits

Author SHA1 Message Date
NanoCode012
68a3c7678a fix: parse model_kwargs (#1825) 2024-08-16 07:51:19 -04:00
NanoCode012
f18925fb4b fix: parse eager_attention (#1824) 2024-08-14 09:46:46 -04:00
2 changed files with 9 additions and 13 deletions

View File

@@ -321,6 +321,8 @@ class ModelInputConfig(BaseModel):
)
trust_remote_code: Optional[bool] = None
+ model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
@@ -614,6 +616,8 @@ class AxolotlInputConfig(
flash_attn_fuse_mlp: Optional[bool] = None
flash_optimum: Optional[bool] = None
+ eager_attention: Optional[bool] = None
unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None

View File

@@ -655,19 +655,11 @@ def load_model(
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
- if cfg.fsdp and not cfg.adapter and cfg.local_rank != 0:
- with init_empty_weights():
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=model_config,
- **model_kwargs,
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=model_config,
- **model_kwargs,
- )
+ model = AutoModelForCausalLM.from_pretrained(
+ base_model,
+ config=model_config,
+ **model_kwargs,
+ )
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (