diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py index 83678430a..37267e3f7 100644 --- a/tests/utils/test_models.py +++ b/tests/utils/test_models.py @@ -21,8 +21,10 @@ class TestModelsUtils: "base_model": "JackFram/llama-68m", "model_type": "LlamaForCausalLM", "tokenizer_type": "LlamaTokenizer", - "load_in_8bit": True, - "load_in_4bit": False, + "quantization": { + "backend": "bnb", + "bits": 8, + }, "adapter": "lora", "flash_attention": False, "sample_packing": True,