Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff * remove unused * add back needed import * fix
This commit is contained in:
@@ -17,7 +17,7 @@ class TestModelsUtils:
|
||||
|
||||
def setup_method(self) -> None:
|
||||
# load config
|
||||
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
|
||||
self.cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
@@ -30,20 +30,16 @@ class TestModelsUtils:
|
||||
"device_map": "auto",
|
||||
}
|
||||
)
|
||||
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
|
||||
spec=PreTrainedTokenizerBase
|
||||
)
|
||||
self.inference = False # pylint: disable=attribute-defined-outside-init
|
||||
self.reference_model = True # pylint: disable=attribute-defined-outside-init
|
||||
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||
self.inference = False
|
||||
self.reference_model = True
|
||||
|
||||
# init ModelLoader
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
inference=self.inference,
|
||||
reference_model=self.reference_model,
|
||||
)
|
||||
self.model_loader = ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
inference=self.inference,
|
||||
reference_model=self.reference_model,
|
||||
)
|
||||
|
||||
def test_set_device_map_config(self):
|
||||
@@ -51,7 +47,7 @@ class TestModelsUtils:
|
||||
device_map = self.cfg.device_map
|
||||
if is_torch_mps_available():
|
||||
device_map = "mps"
|
||||
# pylint: disable=protected-access
|
||||
|
||||
self.model_loader._set_device_map_config()
|
||||
if is_deepspeed_zero3_enabled():
|
||||
assert "device_map" not in self.model_loader.model_kwargs
|
||||
@@ -78,7 +74,6 @@ class TestModelsUtils:
|
||||
self.cfg.gptq = gptq
|
||||
self.cfg.adapter = adapter
|
||||
|
||||
# pylint: disable=protected-access
|
||||
self.model_loader._set_quantization_config()
|
||||
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
|
||||
assert not (
|
||||
@@ -194,7 +189,7 @@ class TestModelsUtils:
|
||||
is_fsdp,
|
||||
expected,
|
||||
):
|
||||
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
res = _get_parallel_config_kwargs(
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
|
||||
Reference in New Issue
Block a user