diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1acaf6ab3..fb363952c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -90,6 +90,7 @@ def load_model( Load a model from a base model and a model type. """ + global LlamaForCausalLM # pylint: disable=global-statement # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit cfg.is_llama_derived_model = "llama" in base_model or (