"""Module for testing models utils file.""" from unittest.mock import MagicMock, patch import pytest from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils.import_utils import is_torch_mps_available from axolotl.utils.dict import DictDefault from axolotl.utils.models import ModelLoader, load_model class TestModelsUtils: """Testing module for models utils.""" def setup_method(self) -> None: # load config self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init { "base_model": "JackFram/llama-68m", "model_type": "LlamaForCausalLM", "tokenizer_type": "LlamaTokenizer", "load_in_8bit": True, "load_in_4bit": False, "adapter": "lora", "flash_attention": False, "sample_packing": True, "device_map": "auto", } ) self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init spec=PreTrainedTokenizerBase ) self.inference = False # pylint: disable=attribute-defined-outside-init self.reference_model = True # pylint: disable=attribute-defined-outside-init # init ModelLoader self.model_loader = ( # pylint: disable=attribute-defined-outside-init ModelLoader( cfg=self.cfg, tokenizer=self.tokenizer, inference=self.inference, reference_model=self.reference_model, ) ) def test_set_device_map_config(self): # check device_map device_map = self.cfg.device_map if is_torch_mps_available(): device_map = "mps" self.model_loader.set_device_map_config() if is_deepspeed_zero3_enabled(): assert "device_map" not in self.model_loader.model_kwargs else: assert device_map in self.model_loader.model_kwargs["device_map"] # check torch_dtype assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"] def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): cfg = DictDefault( { "s2_attention": True, "sample_packing": True, "base_model": "", "model_type": "LlamaForCausalLM", } ) # Mock out call to HF hub with patch( "axolotl.utils.models.load_model_config" ) as mocked_load_model_config: mocked_load_model_config.return_value = {} with pytest.raises(ValueError) as exc: # Should error before hitting tokenizer, so we pass in an empty str load_model(cfg, tokenizer="") assert ( "shifted-sparse attention does not currently support sample packing" in str(exc.value) ) @pytest.mark.parametrize("adapter", ["lora", "qlora", None]) @pytest.mark.parametrize("load_in_8bit", [True, False]) @pytest.mark.parametrize("load_in_4bit", [True, False]) @pytest.mark.parametrize("gptq", [True, False]) def test_set_quantization_config( self, adapter, load_in_8bit, load_in_4bit, gptq, ): # init cfg as args self.cfg.load_in_8bit = load_in_8bit self.cfg.load_in_4bit = load_in_4bit self.cfg.gptq = gptq self.cfg.adapter = adapter self.model_loader.set_quantization_config() if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq: assert not ( hasattr(self.model_loader.model_kwargs, "load_in_8bit") and hasattr(self.model_loader.model_kwargs, "load_in_4bit") ) elif load_in_8bit and self.cfg.adapter is not None: assert self.model_loader.model_kwargs["load_in_8bit"] elif load_in_4bit and self.cfg.adapter is not None: assert self.model_loader.model_kwargs["load_in_4bit"] if (self.cfg.adapter == "qlora" and load_in_4bit) or ( self.cfg.adapter == "lora" and load_in_8bit ): assert self.model_loader.model_kwargs.get( "quantization_config", BitsAndBytesConfig )