fix for local variable 'LlamaForCausalLM' referenced before assignment

This commit is contained in:
Wing Lian
2023-06-10 14:11:13 -04:00
parent 215d775147
commit 14163c15d9

View File

@@ -90,6 +90,7 @@ def load_model(
Load a model from a base model and a model type.
"""
global LlamaForCausalLM # pylint: disable=global-statement
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = "llama" in base_model or (