From c907ac173e586e759da9c71aa2a51e294648d1f0 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 21 Mar 2025 11:02:43 -0400 Subject: [PATCH] adding pre-commit auto-update GH action and bumping plugin versions (#2428) * adding pre-commit auto-update GH action and bumping plugin versions * running updated pre-commit plugins * sorry to revert, but pylint complained * Update .pre-commit-config.yaml Co-authored-by: Wing Lian --------- Co-authored-by: Dan Saunders Co-authored-by: Wing Lian --- .github/workflows/precommit-autoupdate.yml | 49 ++++ .pre-commit-config.yaml | 14 +- cicd/multigpu.py | 5 +- cicd/tests.py | 1 + scripts/chat_datasets.py | 1 + scripts/cutcrossentropy_install.py | 7 +- src/axolotl/cli/cloud/__init__.py | 1 + src/axolotl/cli/cloud/base.py | 1 + src/axolotl/cli/cloud/modal_.py | 1 + src/axolotl/cli/main.py | 1 + src/axolotl/cli/utils.py | 3 +- src/axolotl/convert.py | 1 - src/axolotl/core/chat/format/chatml.py | 1 + src/axolotl/core/chat/format/llama3x.py | 1 + src/axolotl/core/chat/format/shared.py | 1 + src/axolotl/core/chat/messages.py | 1 + src/axolotl/core/datasets/chat.py | 1 + .../core/datasets/transforms/chat_builder.py | 1 + src/axolotl/core/trainer_builder.py | 252 +++++++++--------- src/axolotl/core/trainers/base.py | 12 +- src/axolotl/core/trainers/dpo/__init__.py | 1 + src/axolotl/core/trainers/dpo/args.py | 1 + src/axolotl/core/trainers/dpo/trainer.py | 1 + src/axolotl/core/trainers/grpo/__init__.py | 12 +- src/axolotl/core/trainers/grpo/args.py | 1 + src/axolotl/core/trainers/grpo/trainer.py | 1 + src/axolotl/core/trainers/trl.py | 1 + src/axolotl/core/training_args.py | 1 + src/axolotl/integrations/grokfast/__init__.py | 1 + src/axolotl/integrations/grokfast/args.py | 1 + src/axolotl/integrations/kd/args.py | 12 +- src/axolotl/integrations/liger/__init__.py | 6 +- .../integrations/liger/models/deepseekv2.py | 1 + .../integrations/liger/models/jamba.py | 1 + src/axolotl/integrations/lm_eval/__init__.py | 1 + src/axolotl/integrations/lm_eval/args.py | 1 + src/axolotl/integrations/lm_eval/cli.py | 1 + src/axolotl/kernels/geglu.py | 1 + src/axolotl/kernels/lora.py | 1 + src/axolotl/kernels/quantize.py | 1 + src/axolotl/kernels/swiglu.py | 1 + .../models/mamba/configuration_mamba.py | 1 + src/axolotl/monkeypatch/attention/mllama.py | 13 +- .../monkeypatch/data/batch_dataset_fetcher.py | 1 + .../monkeypatch/llama_attn_hijack_flash.py | 20 +- src/axolotl/monkeypatch/llama_expand_mask.py | 1 + .../monkeypatch/mistral_attn_hijack_flash.py | 22 +- src/axolotl/monkeypatch/mixtral/__init__.py | 1 + src/axolotl/monkeypatch/relora.py | 6 +- .../monkeypatch/stablelm_attn_hijack_flash.py | 2 +- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 1 + src/axolotl/monkeypatch/utils.py | 1 + src/axolotl/monkeypatch/xformers_/__init__.py | 1 + .../prompt_strategies/alpaca_w_system.py | 1 + src/axolotl/prompt_strategies/completion.py | 1 + src/axolotl/prompt_strategies/context_qa.py | 1 + src/axolotl/prompt_strategies/dpo/__init__.py | 1 + src/axolotl/prompt_strategies/dpo/chatml.py | 36 +-- src/axolotl/prompt_strategies/dpo/llama3.py | 36 +-- src/axolotl/prompt_strategies/input_output.py | 1 + .../jinja_template_analyzer.py | 1 + src/axolotl/prompt_strategies/kto/chatml.py | 31 +-- src/axolotl/prompt_strategies/kto/llama3.py | 31 +-- .../prompt_strategies/kto/user_defined.py | 1 + .../prompt_strategies/messages/chat.py | 1 + src/axolotl/prompt_strategies/orcamini.py | 1 + .../prompt_strategies/orpo/chat_template.py | 1 + src/axolotl/prompt_strategies/pretrain.py | 1 + src/axolotl/train.py | 4 +- src/axolotl/utils/__init__.py | 6 +- src/axolotl/utils/bench.py | 1 + src/axolotl/utils/callbacks/__init__.py | 6 +- src/axolotl/utils/callbacks/mlflow_.py | 1 + src/axolotl/utils/callbacks/perplexity.py | 1 + src/axolotl/utils/callbacks/profiler.py | 1 + src/axolotl/utils/chat_templates.py | 1 + src/axolotl/utils/collators/__init__.py | 1 + src/axolotl/utils/collators/core.py | 1 + src/axolotl/utils/collators/mamba.py | 1 + src/axolotl/utils/config/__init__.py | 6 +- .../config/models/input/v0_4_1/__init__.py | 36 +-- .../utils/config/models/internals/__init__.py | 1 + src/axolotl/utils/data/__init__.py | 1 + src/axolotl/utils/distributed.py | 1 + src/axolotl/utils/environment.py | 5 +- src/axolotl/utils/freeze.py | 1 + .../utils/gradient_checkpointing/__init__.py | 1 + src/axolotl/utils/model_shard_quant.py | 15 +- src/axolotl/utils/models.py | 6 +- src/axolotl/utils/optimizers/adopt.py | 1 + src/axolotl/utils/samplers/__init__.py | 1 + src/axolotl/utils/samplers/utils.py | 1 + src/axolotl/utils/schedulers.py | 1 + src/axolotl/utils/trainer.py | 6 +- ...setuptools_axolotl_dynamic_dependencies.py | 1 + .../test_cli_merge_sharded_fsdp_weights.py | 1 + tests/cli/test_cli_sweeps.py | 1 + tests/cli/test_utils.py | 1 + tests/conftest.py | 1 + tests/core/chat/test_messages.py | 1 + tests/e2e/integrations/test_kd.py | 1 + tests/e2e/kernels/test_geglu.py | 1 + tests/e2e/kernels/test_lora.py | 1 + tests/e2e/kernels/test_quantize.py | 1 + tests/e2e/kernels/test_swiglu.py | 1 + tests/e2e/multigpu/test_eval.py | 1 + tests/e2e/multigpu/test_grpo.py | 1 + .../lora_kernels/test_lora_kernel_patching.py | 1 + tests/e2e/patched/test_cli_integrations.py | 1 + tests/e2e/patched/test_unsloth_integration.py | 1 + tests/e2e/patched/test_unsloth_qlora.py | 1 + tests/e2e/test_imports.py | 1 + tests/e2e/utils.py | 1 + tests/integrations/test_liger.py | 1 + .../test_llama_attn_hijack_flash.py | 1 + tests/prompt_strategies/messages/test_chat.py | 1 + tests/prompt_strategies/test_alpaca.py | 1 + .../test_chat_template_utils.py | 1 + .../test_chat_templates_advanced.py | 8 +- tests/prompt_strategies/test_dpo_chatml.py | 1 + .../test_jinja_template_analyzer.py | 1 + tests/prompt_strategies/test_raw_io.py | 1 + tests/test_data.py | 1 + tests/test_dict.py | 1 - tests/test_exact_deduplication.py | 11 +- tests/test_expand_mask.py | 1 + tests/test_lora.py | 1 + tests/test_normalize_config.py | 1 + tests/test_packed_batch_sampler.py | 1 + tests/test_packed_pretraining.py | 1 + tests/test_perplexity.py | 1 + tests/test_schedulers.py | 1 + 132 files changed, 479 insertions(+), 301 deletions(-) create mode 100644 .github/workflows/precommit-autoupdate.yml diff --git a/.github/workflows/precommit-autoupdate.yml b/.github/workflows/precommit-autoupdate.yml new file mode 100644 index 000000000..921742211 --- /dev/null +++ b/.github/workflows/precommit-autoupdate.yml @@ -0,0 +1,49 @@ +name: Pre-commit auto-update + +on: + schedule: + - cron: '0 0 * * 0' # Run weekly + workflow_dispatch: # Manual kickoff + +jobs: + auto-update: + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Update pre-commit hooks + id: update + run: | + pip install pre-commit + pre-commit autoupdate + if [[ -n $(git status --porcelain) ]]; then + echo "changes=true" >> $GITHUB_OUTPUT + git diff .pre-commit-config.yaml > pre-commit-update.diff + fi + + - name: Create Pull Request + if: steps.update.outputs.changes == 'true' + uses: peter-evans/create-pull-request@v6 + with: + token: ${{ secrets.GITHUB_TOKEN }} + branch: update/pre-commit-hooks + delete-branch: true + title: "chore: update pre-commit hooks" + commit-message: "chore: update pre-commit hooks" + body: | + Automated PR to update pre-commit hooks to their latest versions. + +
+ Changes: + + ```diff + ${{ steps.update.outputs.diff }} + ``` +
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95a6e99a0..f627ec13f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: - id: check-yaml - id: end-of-file-fixer @@ -11,23 +11,23 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 25.1.0 hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 6.0.1 hooks: - id: isort - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.1.2 hooks: - id: flake8 - repo: https://github.com/pylint-dev/pylint - rev: c8c96d20cde3552a79858c7456bb1483bf83d633 + rev: v3.3.6 hooks: - id: pylint - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 + rev: v1.15.0 hooks: - id: mypy additional_dependencies: @@ -36,7 +36,7 @@ repos: 'pydantic>=2.5.3', ] - repo: https://github.com/PyCQA/bandit - rev: 1.7.5 + rev: 1.8.3 hooks: - id: bandit args: [ diff --git a/cicd/multigpu.py b/cicd/multigpu.py index eda589dd1..453e8daee 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -1,6 +1,7 @@ """ - modal application to run axolotl gpu tests in Modal - """ +modal application to run axolotl gpu tests in Modal +""" + # pylint: disable=duplicate-code import os diff --git a/cicd/tests.py b/cicd/tests.py index 41ae2306f..d7cdcc473 100644 --- a/cicd/tests.py +++ b/cicd/tests.py @@ -1,4 +1,5 @@ """Modal app to run axolotl GPU tests""" + # pylint: disable=duplicate-code import os diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py index 8ae1e256d..1a85fcef9 100644 --- a/scripts/chat_datasets.py +++ b/scripts/chat_datasets.py @@ -1,6 +1,7 @@ """ helper script to parse chat datasets into a usable yaml """ + import click import yaml from datasets import load_dataset diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index 87f87e575..396f4a655 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -1,4 +1,5 @@ """Script to output the correct installation command for cut-cross-entropy.""" + import importlib.util import sys @@ -17,12 +18,12 @@ if v < V("2.4.0"): cce_spec = importlib.util.find_spec("cut_cross_entropy") -uninstall_prefix = "" +UNINSTALL_PREFIX = "" if cce_spec: if not importlib.util.find_spec("cut_cross_entropy.transformers"): - uninstall_prefix = "pip uninstall -y cut-cross-entropy && " + UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && " print( - uninstall_prefix + UNINSTALL_PREFIX + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"' ) diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index b879601be..5d6900d3e 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -1,6 +1,7 @@ """ launch axolotl in supported cloud platforms """ + from pathlib import Path from typing import Union diff --git a/src/axolotl/cli/cloud/base.py b/src/axolotl/cli/cloud/base.py index 44d1b0c17..eba8be49a 100644 --- a/src/axolotl/cli/cloud/base.py +++ b/src/axolotl/cli/cloud/base.py @@ -1,6 +1,7 @@ """ base class for cloud platforms from cli """ + from abc import ABC, abstractmethod diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index 47bc3221a..ef59ed3d4 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -1,6 +1,7 @@ """ Modal Cloud support from CLI """ + import copy import json import os diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 0f132c133..f46c8efe2 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,4 +1,5 @@ """Click CLI definitions for various axolotl commands.""" + # pylint: disable=redefined-outer-name import logging diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index addfa0ab9..cb61fa371 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -5,7 +5,6 @@ import dataclasses import hashlib import json import logging -import typing from functools import wraps from pathlib import Path from types import NoneType @@ -24,7 +23,7 @@ configure_logging() LOG = logging.getLogger(__name__) -def strip_optional_type(field_type: type | typing._SpecialForm | None): +def strip_optional_type(field_type: type | str | None): """ Extracts the non-`None` type from an `Optional` / `Union` type. diff --git a/src/axolotl/convert.py b/src/axolotl/convert.py index 357e0ec50..d1bdb34db 100644 --- a/src/axolotl/convert.py +++ b/src/axolotl/convert.py @@ -1,6 +1,5 @@ """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" - import json import sys diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py index 315d101a8..04c398fe8 100644 --- a/src/axolotl/core/chat/format/chatml.py +++ b/src/axolotl/core/chat/format/chatml.py @@ -1,6 +1,7 @@ """ ChatML transformation functions for MessageContents """ + from typing import Optional from ..messages import MessageContents, Messages diff --git a/src/axolotl/core/chat/format/llama3x.py b/src/axolotl/core/chat/format/llama3x.py index 17fa7aa8d..a0ce053e5 100644 --- a/src/axolotl/core/chat/format/llama3x.py +++ b/src/axolotl/core/chat/format/llama3x.py @@ -1,6 +1,7 @@ """ Llama 3.x chat formatting functions for MessageContents """ + from typing import Optional from ..messages import MessageContents, Messages diff --git a/src/axolotl/core/chat/format/shared.py b/src/axolotl/core/chat/format/shared.py index 9efa2353d..0a0f56f3a 100644 --- a/src/axolotl/core/chat/format/shared.py +++ b/src/axolotl/core/chat/format/shared.py @@ -1,6 +1,7 @@ """ shared functions for format transforms """ + from axolotl.core.chat.messages import MessageContents, Messages diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py index c879bf477..88ff2b7ad 100644 --- a/src/axolotl/core/chat/messages.py +++ b/src/axolotl/core/chat/messages.py @@ -1,6 +1,7 @@ """ internal message representations of chat messages """ + import json from enum import Enum from typing import Any, Callable, List, Optional, Union diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py index ba257071d..724f12866 100644 --- a/src/axolotl/core/datasets/chat.py +++ b/src/axolotl/core/datasets/chat.py @@ -1,6 +1,7 @@ """ chat dataset module """ + import os from typing import Callable, Optional, Union diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py index 98d5f171a..692fe3ebb 100644 --- a/src/axolotl/core/datasets/transforms/chat_builder.py +++ b/src/axolotl/core/datasets/transforms/chat_builder.py @@ -1,6 +1,7 @@ """ This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. """ + from typing import Any, Mapping, Union diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0c9204747..19b947fb2 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs = {} if self.cfg.include_tokens_per_second is not None: - training_arguments_kwargs[ - "include_tokens_per_second" - ] = self.cfg.include_tokens_per_second + training_arguments_kwargs["include_tokens_per_second"] = ( + self.cfg.include_tokens_per_second + ) if self.cfg.bf16 == "full": training_arguments_kwargs["bf16_full_eval"] = True @@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["seed"] = self.cfg.seed if self.cfg.gradient_checkpointing: - training_arguments_kwargs[ - "gradient_checkpointing" - ] = self.cfg.gradient_checkpointing + training_arguments_kwargs["gradient_checkpointing"] = ( + self.cfg.gradient_checkpointing + ) if self.cfg.gradient_checkpointing_kwargs is not None: - training_arguments_kwargs[ - "gradient_checkpointing_kwargs" - ] = self.cfg.gradient_checkpointing_kwargs + training_arguments_kwargs["gradient_checkpointing_kwargs"] = ( + self.cfg.gradient_checkpointing_kwargs + ) if self.cfg.fsdp: training_arguments_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp_config: @@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed if self.cfg.lr_quadratic_warmup is not None: - training_arguments_kwargs[ - "lr_quadratic_warmup" - ] = self.cfg.lr_quadratic_warmup + training_arguments_kwargs["lr_quadratic_warmup"] = ( + self.cfg.lr_quadratic_warmup + ) if self.cfg.adam_beta1: training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 @@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors if self.cfg.dataloader_pin_memory is not None: - training_arguments_kwargs[ - "dataloader_pin_memory" - ] = self.cfg.dataloader_pin_memory + training_arguments_kwargs["dataloader_pin_memory"] = ( + self.cfg.dataloader_pin_memory + ) if self.cfg.dataloader_num_workers is not None: - training_arguments_kwargs[ - "dataloader_num_workers" - ] = self.cfg.dataloader_num_workers + training_arguments_kwargs["dataloader_num_workers"] = ( + self.cfg.dataloader_num_workers + ) if self.cfg.dataloader_prefetch_factor is not None: - training_arguments_kwargs[ - "dataloader_prefetch_factor" - ] = self.cfg.dataloader_prefetch_factor + training_arguments_kwargs["dataloader_prefetch_factor"] = ( + self.cfg.dataloader_prefetch_factor + ) if self.cfg.dataloader_drop_last is not None: - training_arguments_kwargs[ - "dataloader_drop_last" - ] = self.cfg.dataloader_drop_last + training_arguments_kwargs["dataloader_drop_last"] = ( + self.cfg.dataloader_drop_last + ) elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: training_arguments_kwargs["dataloader_drop_last"] = True if self.cfg.remove_unused_columns is not None: - training_arguments_kwargs[ - "remove_unused_columns" - ] = self.cfg.remove_unused_columns + training_arguments_kwargs["remove_unused_columns"] = ( + self.cfg.remove_unused_columns + ) if not self.cfg.test_datasets and self.cfg.val_set_size == 0: # no eval set, so don't eval @@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.do_causal_lm_eval: training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval if self.cfg.metric_for_best_model: - training_arguments_kwargs[ - "metric_for_best_model" - ] = self.cfg.metric_for_best_model + training_arguments_kwargs["metric_for_best_model"] = ( + self.cfg.metric_for_best_model + ) if self.cfg.greater_is_better: training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better @@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile if self.cfg.torch_compile_backend: - training_arguments_kwargs[ - "torch_compile_backend" - ] = self.cfg.torch_compile_backend + training_arguments_kwargs["torch_compile_backend"] = ( + self.cfg.torch_compile_backend + ) if self.cfg.torch_compile_mode: - training_arguments_kwargs[ - "torch_compile_mode" - ] = self.cfg.torch_compile_mode + training_arguments_kwargs["torch_compile_mode"] = ( + self.cfg.torch_compile_mode + ) # DDP Config if self.cfg.ddp_timeout: @@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.ddp_bucket_cap_mb: training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb if self.cfg.ddp_broadcast_buffers is not None: - training_arguments_kwargs[ - "ddp_broadcast_buffers" - ] = self.cfg.ddp_broadcast_buffers + training_arguments_kwargs["ddp_broadcast_buffers"] = ( + self.cfg.ddp_broadcast_buffers + ) # these are all the "standard" kwargs that are def used training_arguments_kwargs["max_steps"] = ( total_num_steps if self.cfg.max_steps else -1 ) training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len - training_arguments_kwargs[ - "per_device_train_batch_size" - ] = self.cfg.micro_batch_size + training_arguments_kwargs["per_device_train_batch_size"] = ( + self.cfg.micro_batch_size + ) if self.cfg.eval_batch_size: - training_arguments_kwargs[ - "per_device_eval_batch_size" - ] = self.cfg.eval_batch_size + training_arguments_kwargs["per_device_eval_batch_size"] = ( + self.cfg.eval_batch_size + ) if self.cfg.auto_find_batch_size is not None: - training_arguments_kwargs[ - "auto_find_batch_size" - ] = self.cfg.auto_find_batch_size - training_arguments_kwargs[ - "gradient_accumulation_steps" - ] = self.cfg.gradient_accumulation_steps - training_arguments_kwargs[ - "eval_accumulation_steps" - ] = self.cfg.gradient_accumulation_steps + training_arguments_kwargs["auto_find_batch_size"] = ( + self.cfg.auto_find_batch_size + ) + training_arguments_kwargs["gradient_accumulation_steps"] = ( + self.cfg.gradient_accumulation_steps + ) + training_arguments_kwargs["eval_accumulation_steps"] = ( + self.cfg.gradient_accumulation_steps + ) training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate training_arguments_kwargs["output_dir"] = self.cfg.output_dir @@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" - training_arguments_kwargs[ - "alternate_lr_scheduler_type" - ] = self.cfg.lr_scheduler + training_arguments_kwargs["alternate_lr_scheduler_type"] = ( + self.cfg.lr_scheduler + ) else: training_arguments_kwargs["lr_scheduler_type"] = ( self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" @@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} ) training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio - training_arguments_kwargs[ - "cosine_constant_lr_ratio" - ] = self.cfg.cosine_constant_lr_ratio + training_arguments_kwargs["cosine_constant_lr_ratio"] = ( + self.cfg.cosine_constant_lr_ratio + ) training_arguments_kwargs["weight_decay"] = ( self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 ) @@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.eval_sample_packing ) if self.cfg.sample_packing_bin_size is not None: - training_arguments_kwargs[ - "sample_packing_bin_size" - ] = self.cfg.sample_packing_bin_size + training_arguments_kwargs["sample_packing_bin_size"] = ( + self.cfg.sample_packing_bin_size + ) if self.cfg.sample_packing_group_size is not None: - training_arguments_kwargs[ - "sample_packing_group_size" - ] = self.cfg.sample_packing_group_size + training_arguments_kwargs["sample_packing_group_size"] = ( + self.cfg.sample_packing_group_size + ) if self.cfg.sample_packing_eff_est: - training_arguments_kwargs[ - "sample_packing_efficiency" - ] = self.cfg.sample_packing_eff_est + training_arguments_kwargs["sample_packing_efficiency"] = ( + self.cfg.sample_packing_eff_est + ) if self.cfg.relora_steps: training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps - training_arguments_kwargs[ - "relora_warmup_steps" - ] = self.cfg.relora_warmup_steps + training_arguments_kwargs["relora_warmup_steps"] = ( + self.cfg.relora_warmup_steps + ) if self.cfg.relora_anneal_steps: - training_arguments_kwargs[ - "relora_anneal_steps" - ] = self.cfg.relora_anneal_steps + training_arguments_kwargs["relora_anneal_steps"] = ( + self.cfg.relora_anneal_steps + ) if self.cfg.relora_prune_ratio: - training_arguments_kwargs[ - "relora_prune_ratio" - ] = self.cfg.relora_prune_ratio + training_arguments_kwargs["relora_prune_ratio"] = ( + self.cfg.relora_prune_ratio + ) if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers - training_arguments_kwargs[ - "lisa_step_interval" - ] = self.cfg.lisa_step_interval - training_arguments_kwargs[ - "lisa_layers_attribute" - ] = self.cfg.lisa_layers_attribute + training_arguments_kwargs["lisa_step_interval"] = ( + self.cfg.lisa_step_interval + ) + training_arguments_kwargs["lisa_layers_attribute"] = ( + self.cfg.lisa_layers_attribute + ) training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs @@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) if self.cfg.neftune_noise_alpha is not None: - training_arguments_kwargs[ - "neftune_noise_alpha" - ] = self.cfg.neftune_noise_alpha + training_arguments_kwargs["neftune_noise_alpha"] = ( + self.cfg.neftune_noise_alpha + ) trainer_kwargs = {} @@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): importlib.import_module("torchdistx") if self.cfg.optim_target_modules: - training_arguments_kwargs[ - "optim_target_modules" - ] = self.cfg.optim_target_modules + training_arguments_kwargs["optim_target_modules"] = ( + self.cfg.optim_target_modules + ) training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio - training_arguments_kwargs[ - "loraplus_lr_embedding" - ] = self.cfg.loraplus_lr_embedding + training_arguments_kwargs["loraplus_lr_embedding"] = ( + self.cfg.loraplus_lr_embedding + ) training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups if self.cfg.accelerator_config: - training_arguments_kwargs[ - "accelerator_config" - ] = self.cfg.accelerator_config + training_arguments_kwargs["accelerator_config"] = ( + self.cfg.accelerator_config + ) if self.cfg.kd_ce_alpha is not None: training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha @@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.kd_temperature is not None: training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature if self.cfg.kd_zscore_base_temp is not None: - training_arguments_kwargs[ - "kd_zscore_base_temp" - ] = self.cfg.kd_zscore_base_temp + training_arguments_kwargs["kd_zscore_base_temp"] = ( + self.cfg.kd_zscore_base_temp + ) if self.cfg.kd_top_k_before_softmax is not None: - training_arguments_kwargs[ - "kd_top_k_before_softmax" - ] = self.cfg.kd_top_k_before_softmax + training_arguments_kwargs["kd_top_k_before_softmax"] = ( + self.cfg.kd_top_k_before_softmax + ) if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig @@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase): self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} ) if self.cfg.remove_unused_columns is not None: - training_args_kwargs[ - "remove_unused_columns" - ] = self.cfg.remove_unused_columns + training_args_kwargs["remove_unused_columns"] = ( + self.cfg.remove_unused_columns + ) else: training_args_kwargs["remove_unused_columns"] = False if self.cfg.dataloader_pin_memory is not None: - training_args_kwargs[ - "dataloader_pin_memory" - ] = self.cfg.dataloader_pin_memory + training_args_kwargs["dataloader_pin_memory"] = ( + self.cfg.dataloader_pin_memory + ) if self.cfg.dataloader_num_workers is not None: - training_args_kwargs[ - "dataloader_num_workers" - ] = self.cfg.dataloader_num_workers + training_args_kwargs["dataloader_num_workers"] = ( + self.cfg.dataloader_num_workers + ) if self.cfg.dataloader_prefetch_factor is not None: - training_args_kwargs[ - "dataloader_prefetch_factor" - ] = self.cfg.dataloader_prefetch_factor + training_args_kwargs["dataloader_prefetch_factor"] = ( + self.cfg.dataloader_prefetch_factor + ) if self.cfg.gradient_checkpointing: - training_args_kwargs[ - "gradient_checkpointing" - ] = self.cfg.gradient_checkpointing + training_args_kwargs["gradient_checkpointing"] = ( + self.cfg.gradient_checkpointing + ) if self.cfg.gradient_checkpointing_kwargs is not None: - training_args_kwargs[ - "gradient_checkpointing_kwargs" - ] = self.cfg.gradient_checkpointing_kwargs + training_args_kwargs["gradient_checkpointing_kwargs"] = ( + self.cfg.gradient_checkpointing_kwargs + ) else: training_args_kwargs["gradient_checkpointing_kwargs"] = { "use_reentrant": False @@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.dpo_use_weighting is not None: training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting if self.cfg.dpo_use_logits_to_keep is not None: - training_args_kwargs[ - "use_logits_to_keep" - ] = self.cfg.dpo_use_logits_to_keep + training_args_kwargs["use_logits_to_keep"] = ( + self.cfg.dpo_use_logits_to_keep + ) for blocklist_key in blocklist_args_kwargs: if blocklist_key in training_args_kwargs: @@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.adapter and self.peft_config: dpo_trainer_kwargs["peft_config"] = self.peft_config if self.cfg.precompute_ref_log_probs is not None: - dpo_trainer_kwargs[ - "precompute_ref_log_probs" - ] = self.cfg.precompute_ref_log_probs + dpo_trainer_kwargs["precompute_ref_log_probs"] = ( + self.cfg.precompute_ref_log_probs + ) if self.cfg.rl == "grpo": trainer_cls = GRPOStrategy.get_trainer_class() trainer_cls_args = [self.model] diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index c14ed59b5..6570db967 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): "pin_memory": self.args.dataloader_pin_memory, } if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor + dataloader_params["prefetch_factor"] = ( + self.args.dataloader_prefetch_factor + ) sampler = self._get_train_sampler() if isinstance(sampler, BatchSampler): @@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): "pin_memory": self.args.dataloader_pin_memory, } if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor + dataloader_params["prefetch_factor"] = ( + self.args.dataloader_prefetch_factor + ) if isinstance(eval_sampler, BatchSampler): dataloader_params["batch_sampler"] = eval_sampler diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 8187a7fb5..2d6835cf7 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -1,6 +1,7 @@ """ DPO Specific Strategy for training """ + from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py index 4cae67d3e..de1758ed0 100644 --- a/src/axolotl/core/trainers/dpo/args.py +++ b/src/axolotl/core/trainers/dpo/args.py @@ -1,6 +1,7 @@ """ Axolotl specific DPO args """ + from dataclasses import dataclass from trl import DPOConfig diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index a1de4cc82..38b657260 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -1,6 +1,7 @@ """ DPO trainer for axolotl """ + import gc from functools import wraps from typing import Any, Dict, Union diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index ecfc12309..52e6363a2 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -45,9 +45,9 @@ class GRPOStrategy: ) if trl.vllm_gpu_memory_utilization: - grpo_args_kwargs[ - "vllm_gpu_memory_utilization" - ] = trl.vllm_gpu_memory_utilization + grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( + trl.vllm_gpu_memory_utilization + ) if trl.vllm_max_model_len: grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len @@ -86,9 +86,9 @@ class GRPOStrategy: def set_trainer_kwargs(cls, cfg): trainer_kwargs = {} if cfg.trl and cfg.trl.reward_processing_classes: - trainer_kwargs[ - "reward_processing_classes" - ] = cfg.trl.reward_processing_classes + trainer_kwargs["reward_processing_classes"] = ( + cfg.trl.reward_processing_classes + ) return trainer_kwargs @classmethod diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index e14e6b0dc..5460edca9 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -1,6 +1,7 @@ """ Axolotl Specific Training Args """ + from dataclasses import dataclass from trl import GRPOConfig diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 6c8f39ac6..663bed094 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,6 +1,7 @@ """ Axolotl GRPO trainer """ + from accelerate.utils import is_peft_model from accelerate.utils.other import is_compiled_module from transformers import PreTrainedModel diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py index 57f014bd6..7237e792e 100644 --- a/src/axolotl/core/trainers/trl.py +++ b/src/axolotl/core/trainers/trl.py @@ -1,6 +1,7 @@ """ module for TRL PPO training """ + import torch from tqdm import tqdm from trl import PPOTrainer diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 7cace7643..34a79e646 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -1,6 +1,7 @@ """ extra axolotl specific training args """ + from dataclasses import dataclass, field from typing import Optional diff --git a/src/axolotl/integrations/grokfast/__init__.py b/src/axolotl/integrations/grokfast/__init__.py index 3889e927c..c8c352bbe 100644 --- a/src/axolotl/integrations/grokfast/__init__.py +++ b/src/axolotl/integrations/grokfast/__init__.py @@ -1,6 +1,7 @@ """ Grokfast plugin for Axolotl """ + import logging from transformers.trainer_callback import TrainerCallback diff --git a/src/axolotl/integrations/grokfast/args.py b/src/axolotl/integrations/grokfast/args.py index 4776ae60c..ac91c7395 100644 --- a/src/axolotl/integrations/grokfast/args.py +++ b/src/axolotl/integrations/grokfast/args.py @@ -1,6 +1,7 @@ """ config args for grokfast plugin """ + from typing import Optional from pydantic import BaseModel diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index a88a0dc48..2fbba2c6a 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -26,12 +26,12 @@ class KDArgs(BaseModel): """ kd_trainer: Optional[bool] = None # whether to use KD trainer - kd_ce_alpha: Optional[ - float - ] = None # loss coefficient for cross-entropy loss during KD + kd_ce_alpha: Optional[float] = ( + None # loss coefficient for cross-entropy loss during KD + ) kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_temperature: Optional[float] = None # temperature for sampling during KD kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling - kd_top_k_before_softmax: Optional[ - bool - ] = None # whether to sample top k before softmax during KD + kd_top_k_before_softmax: Optional[bool] = ( + None # whether to sample top k before softmax during KD + ) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index b67dd01e6..b8e1fac52 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin): if "cross_entropy" in liger_fn_sig.parameters: kwargs["cross_entropy"] = cfg.liger_cross_entropy if "fused_linear_cross_entropy" in liger_fn_sig.parameters: - kwargs[ - "fused_linear_cross_entropy" - ] = cfg.liger_fused_linear_cross_entropy + kwargs["fused_linear_cross_entropy"] = ( + cfg.liger_fused_linear_cross_entropy + ) if "rms_norm" in liger_fn_sig.parameters: kwargs["rms_norm"] = cfg.liger_rms_norm if "layer_norm" in liger_fn_sig.parameters: diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py index 79fb27436..c29fd4e79 100644 --- a/src/axolotl/integrations/liger/models/deepseekv2.py +++ b/src/axolotl/integrations/liger/models/deepseekv2.py @@ -1,6 +1,7 @@ """ DeepseekV2 model with LigerFusedLinearCrossEntropyLoss """ + # pylint: disable=duplicate-code from typing import List, Optional, Tuple, Union diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py index 40cec63a4..7ab464c88 100644 --- a/src/axolotl/integrations/liger/models/jamba.py +++ b/src/axolotl/integrations/liger/models/jamba.py @@ -1,6 +1,7 @@ """ Jamba model with LigerFusedLinearCrossEntropyLoss """ + # pylint: disable=duplicate-code from typing import Optional, Tuple, Union diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index 0cbc8a49d..8db4dc634 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -1,6 +1,7 @@ """ Module for the Plugin for LM Eval Harness """ + import subprocess # nosec from axolotl.integrations.base import BasePlugin diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py index 721f560e3..d02213177 100644 --- a/src/axolotl/integrations/lm_eval/args.py +++ b/src/axolotl/integrations/lm_eval/args.py @@ -1,6 +1,7 @@ """ Module for handling lm eval harness input arguments. """ + from typing import List, Optional from pydantic import BaseModel diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 4a9bbafe6..19608e1d9 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -1,6 +1,7 @@ """ axolotl CLI for running lm_eval tasks """ + import subprocess # nosec from collections import defaultdict from datetime import datetime diff --git a/src/axolotl/kernels/geglu.py b/src/axolotl/kernels/geglu.py index 4dd70f4cc..0aa035c94 100644 --- a/src/axolotl/kernels/geglu.py +++ b/src/axolotl/kernels/geglu.py @@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. """ + # pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code import torch diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 1f8a8e787..03fca6df4 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models" Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. """ + # pylint: disable=invalid-name from typing import Callable diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py index ea5ecf8e8..b61603fbc 100644 --- a/src/axolotl/kernels/quantize.py +++ b/src/axolotl/kernels/quantize.py @@ -1,4 +1,5 @@ """Dequantization utilities for `bitsandbytes` integration.""" + # pylint: disable=invalid-name,global-statement import ctypes diff --git a/src/axolotl/kernels/swiglu.py b/src/axolotl/kernels/swiglu.py index 20c6e87a0..43a798edc 100644 --- a/src/axolotl/kernels/swiglu.py +++ b/src/axolotl/kernels/swiglu.py @@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. """ + import torch import triton import triton.language as tl diff --git a/src/axolotl/models/mamba/configuration_mamba.py b/src/axolotl/models/mamba/configuration_mamba.py index 5160ee8d7..d6b77b951 100644 --- a/src/axolotl/models/mamba/configuration_mamba.py +++ b/src/axolotl/models/mamba/configuration_mamba.py @@ -1,6 +1,7 @@ """ HF Transformers MambaConfig """ + from transformers import PretrainedConfig diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py index 0b18b716d..c9e8fb5e1 100644 --- a/src/axolotl/monkeypatch/attention/mllama.py +++ b/src/axolotl/monkeypatch/attention/mllama.py @@ -1,6 +1,7 @@ """ Monkeypatch for Vision Llama for FA2 support """ + # pylint: disable=duplicate-code from typing import Optional, Tuple @@ -220,10 +221,10 @@ def patch_mllama(): True ) MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 - MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ - "flash_attention_2" - ] = MllamaTextCrossFlashAttention2 + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = ( + MllamaTextCrossFlashAttention2 + ) # fallback to SDPA - MLLAMA_VISION_ATTENTION_CLASSES[ - "flash_attention_2" - ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] + MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = ( + MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] + ) diff --git a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py index 2e9364e3a..df8d106fd 100644 --- a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py +++ b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py @@ -1,4 +1,5 @@ """monkey patches for the dataset fetcher to handle batches of packed indexes""" + # pylint: disable=protected-access import torch diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index ad0459ccc..998a81027 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -12,7 +12,9 @@ import transformers from einops import rearrange from flash_attn.bert_padding import pad_input, unpad_input from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention +from transformers.models.llama.modeling_llama import ( + LlamaAttention, +) from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer as OriginalLlamaDecoderLayer, ) @@ -490,9 +492,11 @@ def flashattn_forward( # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, + query_padding_mask=( + attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None + ), ) output_unpad = flash_attn_varlen_qkvpacked_func( qkv_unpad, @@ -531,9 +535,11 @@ def flashattn_forward( value_states, kvpacked=True, key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, + query_padding_mask=( + attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None + ), ) if q_unpad.dtype != kv_unpad.dtype: kv_unpad = kv_unpad.to(q_unpad.dtype) diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py index 5738bb543..0277c212a 100644 --- a/src/axolotl/monkeypatch/llama_expand_mask.py +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -1,6 +1,7 @@ """ expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf """ + from typing import Optional import torch diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 1cbc4278b..ac9815fce 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -1,4 +1,5 @@ """Flash attention monkey patch for mistral model""" + # pylint: disable=duplicate-code import logging @@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import ( MistralDecoderLayer as OriginalMistralDecoderLayer, ) -from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv +from transformers.models.mistral.modeling_mistral import ( + apply_rotary_pos_emb, + repeat_kv, +) from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids @@ -243,9 +247,11 @@ def flashattn_forward( # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, + query_padding_mask=( + attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None + ), ) output_unpad = flash_attn_varlen_qkvpacked_func( qkv_unpad, @@ -286,9 +292,11 @@ def flashattn_forward( value_states, kvpacked=True, key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, + query_padding_mask=( + attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None + ), ) if q_unpad.dtype != kv_unpad.dtype: kv_unpad = kv_unpad.to(q_unpad.dtype) diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py index bb5afe847..5b8054000 100644 --- a/src/axolotl/monkeypatch/mixtral/__init__.py +++ b/src/axolotl/monkeypatch/mixtral/__init__.py @@ -1,6 +1,7 @@ """ Patches to support multipack for mixtral """ + import torch diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 1dd758ec5..822fd4465 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -1,4 +1,5 @@ """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.""" + import glob import json import logging @@ -411,7 +412,10 @@ def merge_and_save( if shard_path.endswith(".safetensors"): in_tensors = st.load_file(str(Path(model_src) / shard_path)) else: - in_tensors = torch.load(Path(model_src) / shard_path) + in_tensors = torch.load( + Path(model_src) / shard_path, + weights_only=True, # to prevent arbitrary code execution + ) if "state_dict" in in_tensors: in_tensors = in_tensors["state_dict"] diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index 67e9337e3..c60302111 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -17,7 +17,7 @@ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py # pylint: disable=duplicate-code -""" PyTorch StableLM Epoch model. """ +"""PyTorch StableLM Epoch model.""" import importlib import math from typing import Optional, Tuple, Union diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 00c2dfebc..1cbfefa5b 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -1,6 +1,7 @@ """ fix for FSDP optimizer save in trainer w 4.47.0 """ + import inspect import logging diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index c2772b471..43496c7c8 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -1,6 +1,7 @@ """ Shared utils for the monkeypatches """ + import re from typing import Optional, Tuple diff --git a/src/axolotl/monkeypatch/xformers_/__init__.py b/src/axolotl/monkeypatch/xformers_/__init__.py index bddc036b2..a052ea49e 100644 --- a/src/axolotl/monkeypatch/xformers_/__init__.py +++ b/src/axolotl/monkeypatch/xformers_/__init__.py @@ -1,6 +1,7 @@ """ Fused MLP layer for incrementally improved training efficiency """ + import torch from transformers.models.llama.modeling_llama import LlamaMLP from xformers.ops import SwiGLU diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 8c8cc0743..6873c8e08 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -1,6 +1,7 @@ """ Prompt strategies loader for alpaca instruction datasets with system prompts """ + from typing import Generator, Tuple, Union from axolotl.prompt_tokenizers import PromptTokenizingStrategy diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py index 3285e667c..62a4b90b2 100644 --- a/src/axolotl/prompt_strategies/completion.py +++ b/src/axolotl/prompt_strategies/completion.py @@ -1,6 +1,7 @@ """ Basic completion text """ + from collections import defaultdict from typing import Any, Dict, Generator, Optional, Tuple diff --git a/src/axolotl/prompt_strategies/context_qa.py b/src/axolotl/prompt_strategies/context_qa.py index f87dd8b5c..aac44e0b2 100644 --- a/src/axolotl/prompt_strategies/context_qa.py +++ b/src/axolotl/prompt_strategies/context_qa.py @@ -1,4 +1,5 @@ """Module containing the classes for Context QA Prompt Tokenization Strategies""" + from typing import Tuple from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py index 7f5e6eb64..f67125682 100644 --- a/src/axolotl/prompt_strategies/dpo/__init__.py +++ b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -1,6 +1,7 @@ """ module for DPO style dataset transform strategies """ + from functools import partial from ..base import load as load_base diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 5043a501e..34a54aaa0 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -33,9 +33,9 @@ def default( f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" + ) sample["chosen"] = f"{sample[chosen_key]}<|im_end|>" sample["rejected"] = f"{sample[rejected_key]}<|im_end|>" return sample @@ -52,9 +52,9 @@ def argilla_chat( """ def transform_fn(sample): - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" return sample @@ -78,9 +78,9 @@ def icr( f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>" return sample @@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>" return sample @@ -120,9 +120,9 @@ def prompt_pairs( f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>" return sample @@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" return sample diff --git a/src/axolotl/prompt_strategies/dpo/llama3.py b/src/axolotl/prompt_strategies/dpo/llama3.py index d10aa223b..eed420017 100644 --- a/src/axolotl/prompt_strategies/dpo/llama3.py +++ b/src/axolotl/prompt_strategies/dpo/llama3.py @@ -34,9 +34,9 @@ def default( f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>" sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>" return sample @@ -53,9 +53,9 @@ def argilla_chat( """ def transform_fn(sample): - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" return sample @@ -79,9 +79,9 @@ def icr( f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>" return sample @@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>" return sample @@ -121,9 +121,9 @@ def prompt_pairs( f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>" return sample @@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" return sample diff --git a/src/axolotl/prompt_strategies/input_output.py b/src/axolotl/prompt_strategies/input_output.py index fe14f039c..8be745b20 100644 --- a/src/axolotl/prompt_strategies/input_output.py +++ b/src/axolotl/prompt_strategies/input_output.py @@ -1,4 +1,5 @@ """Module for plain input/output prompt pairs""" + from typing import Generator, Tuple from axolotl.prompt_tokenizers import PromptTokenizingStrategy diff --git a/src/axolotl/prompt_strategies/jinja_template_analyzer.py b/src/axolotl/prompt_strategies/jinja_template_analyzer.py index 01cbd3ad2..a5f89cfe5 100644 --- a/src/axolotl/prompt_strategies/jinja_template_analyzer.py +++ b/src/axolotl/prompt_strategies/jinja_template_analyzer.py @@ -1,4 +1,5 @@ """Module for inspect jinja templates for the variables they use""" + from typing import Dict, Optional, Set, TypedDict, Union from jinja2 import Environment, meta, nodes diff --git a/src/axolotl/prompt_strategies/kto/chatml.py b/src/axolotl/prompt_strategies/kto/chatml.py index 46c305f83..97ae59ed5 100644 --- a/src/axolotl/prompt_strategies/kto/chatml.py +++ b/src/axolotl/prompt_strategies/kto/chatml.py @@ -1,6 +1,7 @@ """ KTO strategies for chatml """ + # pylint: disable=duplicate-code @@ -15,9 +16,9 @@ def argilla( f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["completion"] = f"{sample['completion']}<|im_end|>" return sample @@ -33,9 +34,9 @@ def argilla_chat( """ def transform_fn(sample): - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>" return sample @@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["completion"] = f"{sample['completion']}<|im_end|>" return sample @@ -74,9 +75,9 @@ def prompt_pairs( f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["completion"] = f"{sample['completion']}<|im_end|>" return sample @@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" ) else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["prompt"] = ( + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) sample["completion"] = f"{sample['completion']}<|im_end|>" return sample diff --git a/src/axolotl/prompt_strategies/kto/llama3.py b/src/axolotl/prompt_strategies/kto/llama3.py index 795d343fe..fde3c2ed4 100644 --- a/src/axolotl/prompt_strategies/kto/llama3.py +++ b/src/axolotl/prompt_strategies/kto/llama3.py @@ -1,6 +1,7 @@ """ KTO strategies for llama-3 chat template """ + # pylint: disable=duplicate-code @@ -15,9 +16,9 @@ def argilla( f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["completion"] = f"{sample['completion']}<|eot_id|>" return sample @@ -33,9 +34,9 @@ def argilla_chat( """ def transform_fn(sample): - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>" return sample @@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["completion"] = f"{sample['completion']}<|eot_id|>" return sample @@ -74,9 +75,9 @@ def prompt_pairs( f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["completion"] = f"{sample['completion']}<|eot_id|>" return sample @@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: - sample[ - "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["prompt"] = ( + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) sample["completion"] = f"{sample['completion']}<|eot_id|>" return sample diff --git a/src/axolotl/prompt_strategies/kto/user_defined.py b/src/axolotl/prompt_strategies/kto/user_defined.py index 7e5458bb7..7c68a3000 100644 --- a/src/axolotl/prompt_strategies/kto/user_defined.py +++ b/src/axolotl/prompt_strategies/kto/user_defined.py @@ -1,6 +1,7 @@ """ User-defined KTO strategies """ + # pylint: disable=duplicate-code diff --git a/src/axolotl/prompt_strategies/messages/chat.py b/src/axolotl/prompt_strategies/messages/chat.py index 52124407e..eaed2396a 100644 --- a/src/axolotl/prompt_strategies/messages/chat.py +++ b/src/axolotl/prompt_strategies/messages/chat.py @@ -1,6 +1,7 @@ """ Chat dataset wrapping strategy for new internal messages representations """ + from typing import Any, Callable, Dict, Optional from axolotl.core.datasets.chat import TokenizedChatDataset diff --git a/src/axolotl/prompt_strategies/orcamini.py b/src/axolotl/prompt_strategies/orcamini.py index 04ce5767d..1a694cf1d 100644 --- a/src/axolotl/prompt_strategies/orcamini.py +++ b/src/axolotl/prompt_strategies/orcamini.py @@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:". Not suited/tested for multiple-turn conversations without further adjustments. """ + from typing import Generator, Union from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index e53a54748..fdee28ea1 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -1,4 +1,5 @@ """chatml prompt tokenization strategy for ORPO""" + from typing import Any, Dict, Generator, List, Optional, Tuple from pydantic import BaseModel diff --git a/src/axolotl/prompt_strategies/pretrain.py b/src/axolotl/prompt_strategies/pretrain.py index 8430a7fca..cfb6cb8be 100644 --- a/src/axolotl/prompt_strategies/pretrain.py +++ b/src/axolotl/prompt_strategies/pretrain.py @@ -1,4 +1,5 @@ """pretraining prompt strategies""" + from typing import Generator from transformers import BatchEncoding diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 1ceb5babd..ff486db29 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -406,9 +406,7 @@ def handle_untrained_tokens_fix( ) -def setup_model_and_trainer( - cfg: DictDefault, dataset_meta: TrainDatasetMeta -) -> tuple[ +def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ HFRLTrainerBuilder | HFCausalTrainerBuilder, PeftModel | PreTrainedModel, PreTrainedTokenizer, diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 35ea14551..ffa528cc9 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf(): torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) if torch_major == 2 and torch_minor >= 2: if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - os.environ[ - "PYTORCH_CUDA_ALLOC_CONF" - ] = "expandable_segments:True,roundup_power2_divisions:16" + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( + "expandable_segments:True,roundup_power2_divisions:16" + ) diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index 3d338aff1..d1e972c81 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -1,4 +1,5 @@ """Benchmarking and measurement utilities""" + import functools import torch diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 9ca0e84fe..47c77619a 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer): bench_refs.extend(combined_bench_names[bench_name]["refs"]) bench_preds.extend(combined_bench_names[bench_name]["preds"]) if not pd.isna(bench_score): - results[ - f"{bench_split}_bench_accuracy_{bench_name}" - ] = bench_score + results[f"{bench_split}_bench_accuracy_{bench_name}"] = ( + bench_score + ) bench_scores.append(bench_score) else: results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0 diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index fcbb88edc..47679001f 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -1,4 +1,5 @@ """MLFlow module for trainer callbacks""" + import logging from shutil import copyfile from tempfile import NamedTemporaryFile diff --git a/src/axolotl/utils/callbacks/perplexity.py b/src/axolotl/utils/callbacks/perplexity.py index d3a362c4c..a5b39c304 100644 --- a/src/axolotl/utils/callbacks/perplexity.py +++ b/src/axolotl/utils/callbacks/perplexity.py @@ -1,4 +1,5 @@ """callback to calculate perplexity as an evaluation metric.""" + from typing import Dict, List, Optional import torch diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py index 861696332..36604813f 100644 --- a/src/axolotl/utils/callbacks/profiler.py +++ b/src/axolotl/utils/callbacks/profiler.py @@ -1,6 +1,7 @@ """ HF Trainer callback for creating pytorch profiling snapshots """ + from pathlib import Path from pickle import dump # nosec B403 diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 9f6550689..d3c88334b 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -2,6 +2,7 @@ This module provides functionality for selecting chat templates based on user choices. These templates are used for formatting messages in a conversation. """ + import logging from typing import TYPE_CHECKING, Any, Dict, Optional diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 93502b67d..8c60f223c 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -1,6 +1,7 @@ """ shared axolotl collators for multipack, mamba, multimodal """ + from .batching import ( # noqa: F401 BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, diff --git a/src/axolotl/utils/collators/core.py b/src/axolotl/utils/collators/core.py index 0eae0c3bd..542328127 100644 --- a/src/axolotl/utils/collators/core.py +++ b/src/axolotl/utils/collators/core.py @@ -1,4 +1,5 @@ """ basic shared collator constants """ + IGNORE_INDEX = -100 diff --git a/src/axolotl/utils/collators/mamba.py b/src/axolotl/utils/collators/mamba.py index 0c4a22fcc..33f3991cb 100644 --- a/src/axolotl/utils/collators/mamba.py +++ b/src/axolotl/utils/collators/mamba.py @@ -1,6 +1,7 @@ """ collators for Mamba """ + from dataclasses import dataclass from typing import Dict, Sequence diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 421f3b649..b7096eeab 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -18,7 +18,11 @@ from axolotl.utils.config.models.input.v0_4_1 import ( from axolotl.utils.config.models.input.v0_4_1 import ( AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset +from axolotl.utils.config.models.input.v0_4_1 import ( + DPODataset, + KTODataset, + SFTDataset, +) from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model_config diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 2fa86eced..f1c514a7c 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -200,12 +200,12 @@ class SFTDataset(BaseModel): field_human: Optional[str] = None field_model: Optional[str] = None field_messages: Optional[str] = None - message_field_role: Optional[ - str - ] = None # deprecated, use message_property_mappings - message_field_content: Optional[ - str - ] = None # deprecated, use message_property_mappings + message_field_role: Optional[str] = ( + None # deprecated, use message_property_mappings + ) + message_field_content: Optional[str] = ( + None # deprecated, use message_property_mappings + ) message_property_mappings: Optional[Dict[str, str]] = None message_field_training: Optional[str] = None message_field_training_detail: Optional[str] = None @@ -505,9 +505,9 @@ class HyperparametersConfig(BaseModel): embedding_lr: Optional[float] = None embedding_lr_scale: Optional[float] = None weight_decay: Optional[float] = 0.0 - optimizer: Optional[ - Union[OptimizerNames, CustomSupportedOptimizers] - ] = OptimizerNames.ADAMW_TORCH_FUSED + optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = ( + OptimizerNames.ADAMW_TORCH_FUSED + ) optim_args: Optional[Union[str, Dict[str, Any]]] = Field( default=None, json_schema_extra={"description": "Optional arguments to supply to optimizer."}, @@ -699,9 +699,9 @@ class AxolotlInputConfig( reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None num_labels: Optional[int] = None - dpo_use_weighting: Optional[ - bool - ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + dpo_use_weighting: Optional[bool] = ( + None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + ) dpo_use_logits_to_keep: Optional[bool] = None datasets: Optional[ @@ -780,9 +780,9 @@ class AxolotlInputConfig( # torch_dtype: Optional[torch.dtype] - gradient_checkpointing: Optional[ - Union[Literal["unsloth", "offload"], bool] - ] = Field(default=False) + gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = ( + Field(default=False) + ) gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None unfrozen_parameters: Optional[List[str]] = None @@ -894,9 +894,9 @@ class AxolotlInputConfig( kto_undesirable_weight: Optional[float] = None rl_beta: Optional[float] = None - max_memory: Optional[ - Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] - ] = None + max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = ( + None + ) gpu_memory_limit: Optional[Union[int, str]] = None low_cpu_mem_usage: Optional[bool] = None diff --git a/src/axolotl/utils/config/models/internals/__init__.py b/src/axolotl/utils/config/models/internals/__init__.py index 7b4a12e03..692dee833 100644 --- a/src/axolotl/utils/config/models/internals/__init__.py +++ b/src/axolotl/utils/config/models/internals/__init__.py @@ -1,4 +1,5 @@ """module for gpu capabilities""" + from typing import Optional from pydantic import BaseModel, Field diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 7f90bf3cb..8dedcbe69 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,6 +1,7 @@ """ Data processing modules """ + from axolotl.utils.data.pretraining import ( # noqa: F401 encode_pretraining, wrap_pretraining_dataset, diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 7d6cd597a..7c671abfe 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -1,6 +1,7 @@ """ utility helpers for distributed checks """ + import os import pickle # nosec from contextlib import contextmanager diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 381fec84c..1cc609a68 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -1,10 +1,13 @@ """ utils to get GPU info for the current environment """ + from accelerate.utils.environment import ( check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, ) -from accelerate.utils.environment import get_gpu_info +from accelerate.utils.environment import ( + get_gpu_info, +) def check_cuda_p2p_ib_support(): diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py index f96bc120d..7199eaa36 100644 --- a/src/axolotl/utils/freeze.py +++ b/src/axolotl/utils/freeze.py @@ -1,6 +1,7 @@ """ module to freeze/unfreeze parameters by name """ + import logging import re from typing import Callable, List, Tuple, Union diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py index 8bbf878ad..62fd34b59 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/utils/gradient_checkpointing/__init__.py @@ -1,4 +1,5 @@ """custom checkpointing utils""" + from axolotl.utils.gradient_checkpointing.unsloth import ( Unsloth_Offloaded_Gradient_Checkpointer, ) diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py index ecbe86613..5c5006eda 100644 --- a/src/axolotl/utils/model_shard_quant.py +++ b/src/axolotl/utils/model_shard_quant.py @@ -1,6 +1,7 @@ """ module to handle loading model on cpu/meta device for FSDP """ + import os import time from typing import List, Optional, Type, Union @@ -45,13 +46,13 @@ def _replace_linear( if isinstance(module, torch.nn.Linear) and name not in skip_modules: if issubclass(linear_replacement, Linear4bit): - model._modules[ # pylint: disable=protected-access - name - ] = linear_replacement( - module.in_features, - module.out_features, - module.bias is not None, - **kwargs, + model._modules[name] = ( # pylint: disable=protected-access + linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + **kwargs, + ) ) else: raise ValueError( diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 93d0f13c0..44f570b88 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -741,9 +741,9 @@ class ModelLoader: ) else: if self.cfg.gptq_disable_exllama is not None: - self.model_config.quantization_config[ - "disable_exllama" - ] = self.cfg.gptq_disable_exllama + self.model_config.quantization_config["disable_exllama"] = ( + self.cfg.gptq_disable_exllama + ) self.model_kwargs["quantization_config"] = GPTQConfig( **self.model_config.quantization_config ) diff --git a/src/axolotl/utils/optimizers/adopt.py b/src/axolotl/utils/optimizers/adopt.py index 36217730b..6f064abbf 100644 --- a/src/axolotl/utils/optimizers/adopt.py +++ b/src/axolotl/utils/optimizers/adopt.py @@ -4,6 +4,7 @@ Copied from https://github.com/iShohei220/adopt ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024) Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka """ + # mypy: ignore-errors # pylint: skip-file # flake8: noqa diff --git a/src/axolotl/utils/samplers/__init__.py b/src/axolotl/utils/samplers/__init__.py index 96e00a5d2..eb6dc8ab9 100644 --- a/src/axolotl/utils/samplers/__init__.py +++ b/src/axolotl/utils/samplers/__init__.py @@ -1,5 +1,6 @@ """ axolotl samplers module """ + from .multipack import MultipackBatchSampler # noqa: F401 from .utils import get_dataset_lengths # noqa: F401 diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py index 09f1b081c..a93e84748 100755 --- a/src/axolotl/utils/samplers/utils.py +++ b/src/axolotl/utils/samplers/utils.py @@ -1,6 +1,7 @@ """ helper util to calculate dataset lengths """ + import numpy as np diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index 6f057fbd9..e9989b1af 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -1,4 +1,5 @@ """Module for custom LRScheduler class""" + import math from functools import partial diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8cee3d124..090e677a6 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -538,9 +538,9 @@ def setup_fsdp_envs(cfg): if cfg.fsdp_config.fsdp_auto_wrap_policy: os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: - os.environ[ - "FSDP_TRANSFORMER_CLS_TO_WRAP" - ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ( + cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap + ) def prepare_optim_env(cfg): diff --git a/src/setuptools_axolotl_dynamic_dependencies.py b/src/setuptools_axolotl_dynamic_dependencies.py index 2f10efde6..02a5b8083 100644 --- a/src/setuptools_axolotl_dynamic_dependencies.py +++ b/src/setuptools_axolotl_dynamic_dependencies.py @@ -1,6 +1,7 @@ """ dynamic requirements for axolotl """ + import platform import re from importlib.metadata import PackageNotFoundError, version diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index 18589a80d..ec96b4ed4 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" + # pylint: disable=duplicate-code from unittest.mock import patch diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py index 61c886e80..40b360717 100644 --- a/tests/cli/test_cli_sweeps.py +++ b/tests/cli/test_cli_sweeps.py @@ -1,6 +1,7 @@ """ unit tests for generating sweep configurations """ + from axolotl.cli.main import generate_sweep_configs diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index ecb0025e4..2dab5bba9 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI utils.""" + # pylint: disable=redefined-outer-name import json diff --git a/tests/conftest.py b/tests/conftest.py index 2505e6de9..75b12a036 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ """ shared pytest fixtures """ + import functools import importlib import shutil diff --git a/tests/core/chat/test_messages.py b/tests/core/chat/test_messages.py index b3be56c59..6a69d74e7 100644 --- a/tests/core/chat/test_messages.py +++ b/tests/core/chat/test_messages.py @@ -1,6 +1,7 @@ """ Tests for the chat messages module """ + import unittest import pytest diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index a90b48d67..4f8cde1d7 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -1,6 +1,7 @@ """ e2e tests for kd trainer support in Axolotl """ + from pathlib import Path import pytest diff --git a/tests/e2e/kernels/test_geglu.py b/tests/e2e/kernels/test_geglu.py index c720bbce7..005a1935d 100644 --- a/tests/e2e/kernels/test_geglu.py +++ b/tests/e2e/kernels/test_geglu.py @@ -1,4 +1,5 @@ """Tests for GEGLU activation function Triton kernels.""" + # pylint: disable=duplicate-code import torch diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py index c8becf2da..5ad186cbf 100644 --- a/tests/e2e/kernels/test_lora.py +++ b/tests/e2e/kernels/test_lora.py @@ -1,4 +1,5 @@ """Tests for LoRA custom autograd.""" + # pylint: disable=invalid-name,redefined-outer-name import pytest diff --git a/tests/e2e/kernels/test_quantize.py b/tests/e2e/kernels/test_quantize.py index e4beb846e..ea91407ef 100644 --- a/tests/e2e/kernels/test_quantize.py +++ b/tests/e2e/kernels/test_quantize.py @@ -1,4 +1,5 @@ """Tests for quantization utility functions.""" + # pylint: disable=invalid-name import torch diff --git a/tests/e2e/kernels/test_swiglu.py b/tests/e2e/kernels/test_swiglu.py index 3717402de..60fdafb79 100644 --- a/tests/e2e/kernels/test_swiglu.py +++ b/tests/e2e/kernels/test_swiglu.py @@ -1,4 +1,5 @@ """Tests for SwiGLU activation function Triton kernels.""" + # pylint: disable=duplicate-code import torch diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py index 09561bf26..586da8577 100644 --- a/tests/e2e/multigpu/test_eval.py +++ b/tests/e2e/multigpu/test_eval.py @@ -1,6 +1,7 @@ """ E2E tests for multigpu eval """ + import logging import os from pathlib import Path diff --git a/tests/e2e/multigpu/test_grpo.py b/tests/e2e/multigpu/test_grpo.py index d2b84994b..bb99581ad 100644 --- a/tests/e2e/multigpu/test_grpo.py +++ b/tests/e2e/multigpu/test_grpo.py @@ -1,6 +1,7 @@ """ GRPO test suite """ + import random from pathlib import Path diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index 4e3373367..89f80951b 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -1,4 +1,5 @@ """Integration tests for LoRA activation and attention kernels.""" + # pylint: disable=redefined-outer-name import pytest diff --git a/tests/e2e/patched/test_cli_integrations.py b/tests/e2e/patched/test_cli_integrations.py index ce9396d5f..6c908faf1 100644 --- a/tests/e2e/patched/test_cli_integrations.py +++ b/tests/e2e/patched/test_cli_integrations.py @@ -1,6 +1,7 @@ """ test cases to make sure the plugin args are loaded from the config file """ + from pathlib import Path import yaml diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py index 403d26147..4cd97c894 100644 --- a/tests/e2e/patched/test_unsloth_integration.py +++ b/tests/e2e/patched/test_unsloth_integration.py @@ -1,4 +1,5 @@ """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" + import unittest import pytest diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index da5eaffb6..4cea0d26f 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -1,6 +1,7 @@ """ e2e tests for unsloth qlora """ + import logging import os diff --git a/tests/e2e/test_imports.py b/tests/e2e/test_imports.py index f186eaac4..fc0843479 100644 --- a/tests/e2e/test_imports.py +++ b/tests/e2e/test_imports.py @@ -1,6 +1,7 @@ """ test module to import various submodules that have historically broken due to dependency issues """ + import unittest diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index ff96f1f58..2b218fbf5 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -1,6 +1,7 @@ """ helper utils for tests """ + import os import shutil import tempfile diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py index c75bc1305..cbe1408b8 100644 --- a/tests/integrations/test_liger.py +++ b/tests/integrations/test_liger.py @@ -1,6 +1,7 @@ """ config validation tests for swiglu args """ + # pylint: disable=duplicate-code import logging from typing import Optional diff --git a/tests/monkeypatch/test_llama_attn_hijack_flash.py b/tests/monkeypatch/test_llama_attn_hijack_flash.py index 4521cd07b..08425d4dc 100644 --- a/tests/monkeypatch/test_llama_attn_hijack_flash.py +++ b/tests/monkeypatch/test_llama_attn_hijack_flash.py @@ -1,6 +1,7 @@ """ Unit tests for the monkeypatch utils """ + import unittest import torch diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py index 96c4b6cbb..2681bb743 100644 --- a/tests/prompt_strategies/messages/test_chat.py +++ b/tests/prompt_strategies/messages/test_chat.py @@ -1,6 +1,7 @@ """ tests for chat_template prompt strategy """ + # pylint: disable=duplicate-code import logging import unittest diff --git a/tests/prompt_strategies/test_alpaca.py b/tests/prompt_strategies/test_alpaca.py index 51dd5900b..9e425e0df 100644 --- a/tests/prompt_strategies/test_alpaca.py +++ b/tests/prompt_strategies/test_alpaca.py @@ -1,6 +1,7 @@ """ Test module for alpaca integration w chatml """ + import pytest from datasets import Dataset from tokenizers import AddedToken diff --git a/tests/prompt_strategies/test_chat_template_utils.py b/tests/prompt_strategies/test_chat_template_utils.py index b63c9aa17..66bcb547d 100644 --- a/tests/prompt_strategies/test_chat_template_utils.py +++ b/tests/prompt_strategies/test_chat_template_utils.py @@ -1,6 +1,7 @@ """ Tests for utils in axolotl.utils.chat_templates """ + import unittest import pytest diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 7f3096eb0..69031bd65 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -920,9 +920,11 @@ class TestChatTemplateConfigurations: ) variables = prompter.get_chat_template_msg_variables( - actual_jinja_template - if actual_jinja_template - else actual_tokenizer.get_chat_template(), + ( + actual_jinja_template + if actual_jinja_template + else actual_tokenizer.get_chat_template() + ), "messages", ) diff --git a/tests/prompt_strategies/test_dpo_chatml.py b/tests/prompt_strategies/test_dpo_chatml.py index 34c29275b..93793b2c5 100644 --- a/tests/prompt_strategies/test_dpo_chatml.py +++ b/tests/prompt_strategies/test_dpo_chatml.py @@ -1,6 +1,7 @@ """ Tests for loading DPO preference datasets with chatml formatting """ + import unittest import pytest diff --git a/tests/prompt_strategies/test_jinja_template_analyzer.py b/tests/prompt_strategies/test_jinja_template_analyzer.py index 004f81099..f666c738c 100644 --- a/tests/prompt_strategies/test_jinja_template_analyzer.py +++ b/tests/prompt_strategies/test_jinja_template_analyzer.py @@ -1,6 +1,7 @@ """ tests for jinja_template_analyzer """ + import logging import pytest diff --git a/tests/prompt_strategies/test_raw_io.py b/tests/prompt_strategies/test_raw_io.py index 967de169f..082e31ee6 100644 --- a/tests/prompt_strategies/test_raw_io.py +++ b/tests/prompt_strategies/test_raw_io.py @@ -1,6 +1,7 @@ """ Test module for raw i/o data for prompts """ + import pytest from datasets import Dataset from tokenizers import AddedToken diff --git a/tests/test_data.py b/tests/test_data.py index e156e1f3c..141f3ed21 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,6 +1,7 @@ """ test module for the axolotl.utils.data module """ + import unittest from transformers import LlamaTokenizer diff --git a/tests/test_dict.py b/tests/test_dict.py index 2007cb085..0bcf8ca7b 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -1,6 +1,5 @@ """Module for testing DictDefault class""" - import unittest import pytest diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index bc0734ed3..3fc315b2e 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -3,6 +3,7 @@ Test suite for functions in the axolotl.utils.data.utils module, focusing on the Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command. """ + import hashlib import unittest from unittest.mock import patch @@ -386,11 +387,11 @@ class TestWrongCollisions(unittest.TestCase): @patch( "axolotl.utils.data.utils.sha256", - side_effect=lambda x: hashlib.sha256( - "forced_collision_hash".encode("utf-8") - ).hexdigest() - if "sample 5" in x - else hashlib.sha256(x.encode("utf-8")).hexdigest(), + side_effect=lambda x: ( + hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest() + if "sample 5" in x + else hashlib.sha256(x.encode("utf-8")).hexdigest() + ), ) def test_deduplication_wrong_collision_train_eval(self, _mock_sha256): dedup_train, dedup_eval, _ = deduplicate_and_log_datasets( diff --git a/tests/test_expand_mask.py b/tests/test_expand_mask.py index 01241c295..1c69ca234 100644 --- a/tests/test_expand_mask.py +++ b/tests/test_expand_mask.py @@ -1,6 +1,7 @@ """ Unit tests for the monkey patch for expand mask to handle packed sequences """ + import unittest import torch diff --git a/tests/test_lora.py b/tests/test_lora.py index b917ff3f9..540371bef 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -1,6 +1,7 @@ """ tests for loading loras """ + from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 0d663183d..c8ca3e550 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -1,6 +1,7 @@ """ Test classes for checking functionality of the cfg normalization """ + import unittest from unittest.mock import patch diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index b52320e2a..55a0afaec 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -1,4 +1,5 @@ """Module for testing streaming dataset sequence packing""" + import pytest from datasets import concatenate_datasets, load_dataset from torch.utils.data import DataLoader, RandomSampler diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index 9f9ae60fb..71c9a6861 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -1,4 +1,5 @@ """Module for testing streaming dataset sequence packing""" + import functools import unittest diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index b32cd5283..9a1c9b223 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -1,4 +1,5 @@ """unit tests for perplexity eval callback""" + # pylint: disable=redefined-outer-name from pytest import fixture diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index bd37bf01d..92664cca8 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -1,6 +1,7 @@ """ test module for the axolotl.utis.data module """ + import unittest import torch