Add ruff, remove black, isort, flake8, pylint (#3092)

* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
This commit is contained in:
Dan Saunders
2025-08-23 23:37:33 -04:00
committed by GitHub
parent eea7a006e1
commit 79ddaebe9a
286 changed files with 10979 additions and 11435 deletions

View File

@@ -17,7 +17,7 @@ class TestModelsUtils:
def setup_method(self) -> None:
# load config
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
self.cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
@@ -30,20 +30,16 @@ class TestModelsUtils:
"device_map": "auto",
}
)
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
spec=PreTrainedTokenizerBase
)
self.inference = False # pylint: disable=attribute-defined-outside-init
self.reference_model = True # pylint: disable=attribute-defined-outside-init
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
self.inference = False
self.reference_model = True
# init ModelLoader
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer=self.tokenizer,
inference=self.inference,
reference_model=self.reference_model,
)
self.model_loader = ModelLoader(
cfg=self.cfg,
tokenizer=self.tokenizer,
inference=self.inference,
reference_model=self.reference_model,
)
def test_set_device_map_config(self):
@@ -51,7 +47,7 @@ class TestModelsUtils:
device_map = self.cfg.device_map
if is_torch_mps_available():
device_map = "mps"
# pylint: disable=protected-access
self.model_loader._set_device_map_config()
if is_deepspeed_zero3_enabled():
assert "device_map" not in self.model_loader.model_kwargs
@@ -78,7 +74,6 @@ class TestModelsUtils:
self.cfg.gptq = gptq
self.cfg.adapter = adapter
# pylint: disable=protected-access
self.model_loader._set_quantization_config()
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
assert not (
@@ -194,7 +189,7 @@ class TestModelsUtils:
is_fsdp,
expected,
):
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
res = _get_parallel_config_kwargs(
world_size,
tensor_parallel_size,
context_parallel_size,