feat: add falcon-h1 into axolotl (#2811) [skip ci]
* feat: add falcon-h1 into axolotl * fix pre-commit * review * fix: remove packing
This commit is contained in:
@@ -504,6 +504,9 @@ class ModelLoader:
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
if self.cfg.model_config_type == "falcon_h1":
|
||||
# output projection cannot be quantized for Falcon-H1 models
|
||||
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
||||
|
||||
if self.cfg.bnb_config_kwargs:
|
||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
||||
@@ -518,6 +521,9 @@ class ModelLoader:
|
||||
# Exclude mamba blocks from int8 quantization for jamba
|
||||
if self.cfg.model_config_type == "jamba":
|
||||
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||
if self.cfg.model_config_type == "falcon_h1":
|
||||
# output projection cannot be quantized for Falcon-H1 models
|
||||
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -54,6 +54,7 @@ class ChatTemplate(str, Enum):
|
||||
jinja = "jinja"
|
||||
qwen_25 = "qwen_25"
|
||||
qwen3 = "qwen3"
|
||||
falcon_h1 = "falcon_h1"
|
||||
tokenizer_default = "tokenizer_default"
|
||||
exaone = "exaone"
|
||||
metharme = "metharme"
|
||||
|
||||
Reference in New Issue
Block a user