Feat: Add example for Mistral (#644)

* Feat: Add example for Mistral

* chore: turn off flash

* chore: add is_mistral_derived_model

* chore: update following PR
This commit is contained in:
NanoCode012
2023-09-28 20:15:00 +09:00
committed by GitHub
parent 383f88d7a7
commit eb41f76f92
3 changed files with 79 additions and 3 deletions

View File

@@ -82,7 +82,7 @@ def normalize_config(cfg):
  cfg.is_llama_derived_model = (
  (hasattr(model_config, "model_type") and model_config.model_type == "llama")
  or cfg.is_llama_derived_model
- or "llama" in cfg.base_model
+ or "llama" in cfg.base_model.lower()
  or (cfg.model_type and "llama" in cfg.model_type.lower())
  )
@@ -98,10 +98,23 @@ def normalize_config(cfg):
  ]
  )
  or cfg.is_falcon_derived_model
- or "falcon" in cfg.base_model
+ or "falcon" in cfg.base_model.lower()
  or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
  )
+ cfg.is_mistral_derived_model = (
+ (
+ hasattr(model_config, "model_type")
+ and model_config.model_type
+ in [
+ "mistral",
+ ]
+ )
+ or cfg.is_mistral_derived_model
+ or "mistral" in cfg.base_model.lower()
+ or (cfg.model_type and "mistral" in cfg.model_type.lower())
+ )
  log_gpu_memory_usage(LOG, "baseline", cfg.device)