Compare commits

...

10 Commits

Author SHA1 Message Date
NanoCode012
e37a768960 feat: add baseten to lmeval 2025-08-29 18:02:26 +07:00
Wing Lian
6afba3871d Add support for PyTorch 2.8.0 (#3106)
* Add support for PyTorch 2.8.0

* loosen triton requirements

* handle torch 2.8.0 in setup.py

* fix versions

* no vllm for torch 2.8.0

* remove comment

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-28 09:10:40 -04:00
Dan Saunders
dc338c3b0e Update .coderabbit.yaml (#3109) [skip ci]
Oops, should be false.
2025-08-27 09:50:52 -04:00
salman
d0d2fc5606 Tokens per second logging [skip-e2e] (#3072) 2025-08-27 09:10:14 +01:00
Wing Lian
e1131e9619 make always skip_move_to_device default as true (#3084) 2025-08-26 09:30:22 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Wing Lian
0e9945e3b9 deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli

* cleanup
2025-08-26 09:29:50 -04:00
NanoCode012
0de254a0d0 feat: add gemma3_text attention handling for lora kernels (#3103) 2025-08-26 16:47:26 +07:00
Dan Saunders
79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00
Dan Saunders
eea7a006e1 make multipack sampler patch explicit (#3096)
* make multipack sampler patch explicit

* combining
2025-08-22 14:29:10 -04:00
308 changed files with 11591 additions and 11492 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101,B615
skips = B101,B615,B102,B110

View File

@@ -12,6 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
auto_incremental_review: false
chat:
auto_reply: true

View File

@@ -1,5 +0,0 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -36,6 +36,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -110,6 +115,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -169,6 +179,12 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -33,13 +33,6 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -240,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -298,6 +298,12 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -334,10 +340,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -1,4 +0,0 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,22 +10,12 @@ repos:
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 25.1.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.9
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.3.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.8
hooks:
- id: pylint
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
hooks:

View File

@@ -1,15 +0,0 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -2,8 +2,6 @@
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -63,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
exit(exit_code)
@app.function(

View File

@@ -1,7 +1,5 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -70,4 +68,4 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
exit(exit_code)

View File

@@ -47,7 +47,6 @@ class QuartoGenerator:
"""Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+)
@@ -124,7 +123,6 @@ class QuartoGenerator:
return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type(
self, field_type
) -> list[type[BaseModel]]:
@@ -318,7 +316,6 @@ class QuartoGenerator:
return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source(
self, model_class: type[BaseModel]
) -> list[dict]:
@@ -503,7 +500,7 @@ class QuartoGenerator:
nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", [])
except Exception: # pylint: disable=broad-exception-caught
except Exception:
# Fallback: use model fields directly
nested_properties = {}
nested_required = []
@@ -607,7 +604,7 @@ class QuartoGenerator:
schema = model_class.model_json_schema()
properties = schema.get("properties", {})
required = schema.get("required", [])
except Exception as e: # pylint: disable=broad-exception-caught
except Exception as e:
print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
)

View File

@@ -0,0 +1,10 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

File diff suppressed because it is too large Load Diff

View File

@@ -26,3 +26,34 @@ include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long
"C901", # Too complex
"B019", # Use of functools.cache on methods
"E722", # Bare except
"F821", # Undefined name (for dynamic exec)
]
[tool.ruff.lint.isort]
known-third-party = ["wandb", "comet_ml"]
known-local-folder = ["src", "tests"]
# Black-compatible isort settings
force-single-line = false
combine-as-imports = true
split-on-trailing-comma = true
[tool.ruff.format]
# Use black's formatting style exactly
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false

View File

@@ -2,8 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3

View File

@@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
f"No conversation field found in dataset: {', '.join(feature_keys)}"
)
ds_cfg["field_messages"] = field_messages
@@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["role"]:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
f"No role field found in messages: {', '.join(message_fields)}"
)
for key in ["content", "text", "value"]:
@@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["content"]:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
f"No content field found in messages: {', '.join(message_fields)}"
)
ds_cfg["message_property_mappings"] = message_property_mappings

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
)

View File

@@ -1,11 +1,10 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]

View File

@@ -64,7 +64,9 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 7):
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")

View File

@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
global HAS_PRINTED_LOGO
if HAS_PRINTED_LOGO:
return
if is_main_process():

View File

@@ -7,6 +7,8 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -38,8 +40,15 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}
@@ -58,8 +67,16 @@ def do_cli_lm_eval(
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -0,0 +1,68 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)
def lm_eval(
self,
config_yaml: str,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(
dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
)
shutil.copyfile(
dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -0,0 +1,8 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl lm-eval eval.yaml

View File

@@ -0,0 +1,81 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = (
1 # int(cloud_config.get("gpu_count", 1)) # only single GPU supported at the moment
)
node_count = (
1 # int(cloud_config.get("node_count", 1)) # only single node support for lmeval
)
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
# launcher = cloud_config.get("launcher", "accelerate")
# launcher_args = cloud_config.get("launcher_args", [])
script_name = "eval.sh"
# launcher_args_str = ""
# if launcher_args:
# launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
# "AXOLOTL_LAUNCHER": launcher,
# "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -0,0 +1,77 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
exit(exit_code)
# Commit writes to volume.
if volumes:
@@ -130,7 +130,6 @@ class ModalCloud(Cloud):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
@@ -177,8 +176,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
@@ -187,7 +186,7 @@ class ModalCloud(Cloud):
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
def get_train_gpu(self):
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
@@ -277,7 +276,7 @@ def _train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs, # pylint: disable=unused-argument
**kwargs,
):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:

View File

@@ -210,7 +210,7 @@ def load_cfg(
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
except:
gpu_version = None
prepare_plugins(cfg)

View File

@@ -28,7 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -49,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -35,7 +35,7 @@ def get_multi_line_input() -> str:
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
instruction += line
return instruction
@@ -167,7 +167,6 @@ def do_inference_gradio(
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
@@ -252,7 +251,7 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)

View File

@@ -1,7 +1,5 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import os
import subprocess # nosec B404
from typing import Literal, Optional

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
def commit_tensor(self, read_item, tensor):
tensor.copy_(tensor.to(torch.bfloat16))
@@ -59,10 +59,10 @@ def _distributed_checkpoint_to_merged_weights(
state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
dist_cp_format_utils._load_state_dict(
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
planner=BFloat16CastPlanner(),
no_dist=True,
)
@@ -191,7 +191,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -73,7 +73,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
except Exception: # nosec B110
pass
# fmt: on
@@ -95,7 +95,7 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)

View File

@@ -84,5 +84,6 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -59,7 +59,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -65,7 +65,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
for field in reversed(dataclasses.fields(config_class)):
field_type = _strip_optional_type(field.type)
if field_type == bool:
if field_type is bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -103,7 +103,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
if field_type == bool:
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(

View File

@@ -49,7 +49,10 @@ def generate_sweep_configs(
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
full_combo = {
**dict(zip(param_names, reg_combo, strict=False)),
**paired_set,
}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
@@ -58,7 +61,7 @@ def generate_sweep_configs(
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
for param_name, param_value in zip(param_names, reg_combo, strict=False):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)

View File

@@ -95,7 +95,6 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",
suffix=".yaml",

View File

@@ -39,7 +39,7 @@ def do_vllm_serve(
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -68,7 +68,6 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets

View File

@@ -67,9 +67,7 @@ class JsonToJsonlConverter:
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(
self, input_file_path, output_file_path
): # pylint: disable=unused-argument
def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -84,9 +84,7 @@ def create_causal_mask(
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
@@ -103,9 +101,7 @@ def create_causal_mask(
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
config._attn_implementation # pylint: disable=protected-access
]
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -38,13 +36,14 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
import torch._dynamo
class TrainerBuilderBase(abc.ABC):
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
@@ -260,14 +265,14 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
from axolotl.contribs.mit.dion import (
DionOptimizerFactory,
)
@@ -414,12 +419,8 @@ class TrainerBuilderBase(abc.ABC):
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.accumulated_cache_size_limit = 256
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
@@ -516,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -344,16 +344,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
training_args = training_args_cls(
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
training_args.run_name = None
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
@@ -406,6 +404,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

View File

@@ -168,16 +168,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
training_args = training_args_cls(
logging_first_step=True,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
training_args.run_name = None
return training_args, trainer_kwargs

View File

@@ -10,7 +10,7 @@ from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
message_index: Optional[int] = None,
) -> Messages:
if message.is_chat_formatted:
return message

View File

@@ -15,11 +15,11 @@ class MessageRoles(str, Enum):
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
ipython = (
# for responses from builtin tools
"ipython"
)
@@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum):
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
special_token = "special_token" # nosec B105
text = "text"
image = "image"
audio = "audio"
tool_call = "tool_call"
tool_response = "tool_response"
class SpecialToken(str, Enum):
@@ -43,8 +43,8 @@ class SpecialToken(str, Enum):
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
bos_token = "bos_token" # nosec B105
eos_token = "eos_token" # nosec B105
class ToolCallFunction(BaseModel):
@@ -73,7 +73,7 @@ class ToolCallContents(BaseModel):
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
id: Optional[str] = None
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
@@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel):
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
id: Optional[str] = None
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}

View File

@@ -1,23 +1,17 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
This module contains a function that builds a transform that takes a row from the
dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
from typing import Any, Mapping
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
def chat_message_transform_builder(
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: str | list[str] | None = None, # commonly "content"
message_field_training: str | list[str] | None = None, # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
@@ -39,6 +33,12 @@ def chat_message_transform_builder( # pylint: disable=dangerous-default-value
A function that takes a list of conversations and returns a list of messages.
"""
if message_field_training is None:
message_field_training = ["train", "weight"]
if message_field_content is None:
message_field_content = ["value", "text", "content"]
if message_field_role is None:
message_field_role = ["role", "from"]
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer

View File

@@ -1,7 +1,5 @@
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
import os
@@ -44,6 +42,7 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -65,6 +64,15 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -80,7 +88,6 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -285,9 +292,9 @@ class AxolotlTrainer(
# fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore
else:
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
self._eval_dataloaders = {dataloader_key: dataloader}
# fmt: on
return self.accelerator.prepare(dataloader)
@@ -329,6 +336,17 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -443,7 +461,7 @@ class AxolotlTrainer(
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
num_items_in_batch=None,
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
@@ -524,15 +542,10 @@ class AxolotlTrainer(
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
AcceleratorState._reset_state(reset_partial_state=True)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -540,7 +553,6 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:
@@ -581,12 +593,19 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
@@ -662,6 +681,11 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -101,11 +101,11 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
loss_type: str = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
self.loss_type = "ipo"
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
self.loss_type = loss_type
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -128,9 +128,7 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(
cls, cfg: DictDefault
) -> list[Any]: # pylint: disable=unused-argument
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
@@ -151,7 +149,7 @@ class GRPOStrategy:
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
def get_collator(cls, *args, **kwargs):
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None

View File

@@ -1,7 +1,5 @@
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
@@ -52,7 +50,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
@@ -253,7 +250,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
@@ -266,7 +263,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator = self._get_collator_with_removed_columns(
data_collator,
description="training",
)
@@ -308,10 +305,10 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
# pylint: disable=access-member-before-definition
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
self._move_model_to_vllm()
# pylint: disable=attribute-defined-outside-init
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
@@ -333,8 +330,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
group_prompts = all_prompts_text[
group_leader_rank
* len(prompts_text) : (group_leader_rank + 1)
group_leader_rank * len(prompts_text) : (
group_leader_rank + 1
)
* len(prompts_text) : self.num_generations
]
@@ -485,7 +483,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=False):
bootstrap = (
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
)
@@ -503,6 +501,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.reward_funcs,
self.reward_processing_classes,
self.reward_func_names,
strict=False,
)
):
with profiling_context(self, reward_func_name):
@@ -511,14 +510,17 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [
{"messages": p + c} for p, c in zip(prompts, completions)
{"messages": p + c}
for p, c in zip(prompts, completions, strict=False)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [
p + c for p, c in zip(prompts, completions, strict=False)
]
reward_inputs = reward_processing_class(
text=texts,
return_tensors="pt",
@@ -564,7 +566,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward."
"Please ensure that at least one reward function returns a valid reward.",
stacklevel=2,
)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the

View File

@@ -5,7 +5,6 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
@@ -15,8 +14,8 @@ class AxolotlMambaTrainer(AxolotlTrainer):
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
return_outputs=False,
num_items_in_batch=None,
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin

View File

@@ -92,7 +92,7 @@ def get_lora_act_offloading_ctx_manager(
`contextlib.ContextDecorator`:
Activation offloading context manager for the model.
"""
# pylint: disable=unnecessary-dunder-call
activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory,
use_streams=use_streams,

View File

@@ -26,7 +26,6 @@ class DistributedParallelMixin(Trainer):
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"

View File

@@ -70,11 +70,11 @@ class OptimizerMixin(Trainer):
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
lr = optimizer_kwargs["lr"]
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
lr *= self.args.embedding_lr_scale
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
lr = self.args.embedding_lr
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
@@ -143,7 +143,7 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer = create_loraplus_optimizer(
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
@@ -185,17 +185,15 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
LOG.info(f"skipped {module}: {skipped / 2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
LOG.info(f"skipped: {skipped / 2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
self.optimizer = smp.DistributedOptimizer(self.optimizer)
return self.optimizer

View File

@@ -46,7 +46,7 @@ class SchedulerMixin(Trainer):
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
if self.lr_scheduler is None: # type: ignore
# fmt: on
plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
@@ -90,7 +90,7 @@ class SchedulerMixin(Trainer):
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -98,7 +98,7 @@ class SchedulerMixin(Trainer):
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -107,7 +107,7 @@ class SchedulerMixin(Trainer):
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = get_cosine_schedule_with_min_lr(
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -133,7 +133,7 @@ class SchedulerMixin(Trainer):
)
if not self.lr_scheduler:
super().create_scheduler(num_training_steps, optimizer)
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = JaggedLRRestartScheduler(
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,

View File

@@ -14,7 +14,6 @@ class AxolotlTrainingMixins:
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
@@ -50,6 +49,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -26,7 +26,7 @@ class TokenizedPromptDataset(Dataset):
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
def __init__(
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
@@ -99,7 +99,7 @@ class ConstantLengthDataset(IterableDataset):
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
def __init__(
self,
tokenizer,
datasets,

View File

@@ -79,7 +79,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps

View File

@@ -76,7 +76,7 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: dict): # pylint: disable=unused-argument
def register(self, cfg: dict):
"""Registers the plugin with the given configuration as an unparsed dict.
Args:
@@ -104,14 +104,13 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
def pre_model_load(self, cfg: DictDefault):
"""Performs actions before the model is loaded.
Args:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied.
@@ -119,7 +118,6 @@ class BasePlugin:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded.
@@ -128,7 +126,6 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded.
@@ -137,7 +134,6 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded.
@@ -146,7 +142,6 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
@@ -157,7 +152,6 @@ class BasePlugin:
The first non-`None` trainer class returned by a plugin.
"""
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created.
@@ -166,7 +160,7 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
def get_training_args(self, cfg: DictDefault):
"""
Returns custom training arguments to set on TrainingArgs.
@@ -177,9 +171,7 @@ class BasePlugin:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False):
"""
Returns a custom class for the collator.
@@ -191,7 +183,6 @@ class BasePlugin:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -203,7 +194,6 @@ class BasePlugin:
The created optimizer.
"""
# pylint: disable=unused-argument
def create_lr_scheduler(
self,
cfg: DictDefault,
@@ -223,7 +213,6 @@ class BasePlugin:
The created learning rate scheduler.
"""
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
@@ -238,7 +227,6 @@ class BasePlugin:
"""
return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
@@ -254,7 +242,6 @@ class BasePlugin:
"""
return []
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
@@ -263,7 +250,7 @@ class BasePlugin:
model: The loaded model.
"""
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
def post_train_unload(self, cfg: DictDefault):
"""Performs actions after training is complete and the model is unloaded.
Args:
@@ -311,7 +298,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager: # pylint: disable=too-many-public-methods
class PluginManager:
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.

View File

@@ -50,15 +50,9 @@ def merge_input_args():
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, globals(), namespace
)
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
"AxolotlInputConfig"
]
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
"AxolotlConfigWCapabilities"
]
exec(dynamic_input, globals(), namespace) # nosec B102
AxolotlInputConfig = namespace["AxolotlInputConfig"]
AxolotlConfigWCapabilities = namespace["AxolotlConfigWCapabilities"]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
@@ -74,7 +68,7 @@ def merge_training_args() -> Type:
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
@@ -93,11 +87,7 @@ def merge_training_args() -> Type:
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
exec(dynamic_input, {**globals(), **local_vars}, namespace) # nosec B102
AxolotlTrainingMixins = namespace["AxolotlTrainingMixins"]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
```
## Usage

View File

@@ -18,6 +18,7 @@ Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
from functools import partial
@@ -28,13 +29,13 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
from .args import CutCrossEntropyArgs as CutCrossEntropyArgs
LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
)
@@ -106,9 +107,7 @@ class CutCrossEntropyPlugin(BasePlugin):
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
def patch_generic(maybe_model, patch_options, model_type: str):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -121,12 +120,10 @@ class CutCrossEntropyPlugin(BasePlugin):
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)
cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options
model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "

View File

@@ -15,6 +15,7 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator

View File

@@ -7,7 +7,7 @@ from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .args import GrokfastArgs as GrokfastArgs
from .optimizer import gradfilter_ema
LOG = get_logger(__name__)
@@ -24,12 +24,10 @@ class GrokfastCallbackHandler(TrainerCallback):
self.alpha = alpha
self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
def on_train_begin(self, *args_, **kwargs):
self.grads = None
def on_pre_optimizer_step(
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
def on_pre_optimizer_step(self, args_, state, control, **kwargs):
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control

View File

@@ -1,7 +1,6 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque
from typing import Dict, Literal, Optional

View File

@@ -15,6 +15,7 @@
"""
Plugin init to add KD support to Axolotl.
"""
from typing import Any
from transformers import Trainer
@@ -22,7 +23,7 @@ from transformers import Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
from .args import KDArgs as KDArgs
class KDPlugin(BasePlugin):

View File

@@ -15,6 +15,7 @@
"""
Plugin args for KD support.
"""
from dataclasses import dataclass
from enum import Enum
@@ -26,8 +27,8 @@ class InferenceServerType(str, Enum):
Online inferences server types to handle different request args
"""
vllm = "vllm" # pylint: disable=invalid-name
sglang = "sglang" # pylint: disable=invalid-name
vllm = "vllm"
sglang = "sglang"
class KDArgs(BaseModel):

View File

@@ -19,9 +19,7 @@ class KDTemperatureSchedulerCallback(TrainerCallback):
self.trainer = trainer
def on_step_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
def on_step_end(self, args, state, control, **kwargs):
# cosine decay temperature over the max steps
progress = state.global_step / state.max_steps

View File

@@ -15,6 +15,7 @@
"""
Chat template prompt strategy loader with KD support
"""
import logging
from typing import Any, Dict
@@ -192,7 +193,6 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
"""
Transform logprobs to target format for KD training
"""
# pylint: disable=duplicate-code
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
@@ -240,7 +240,7 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
target_mask.append([1] * top_k)
for token_pos_logprobs, pos_target_token_ids in zip(
logprobs, sample["target_token_ids"]
logprobs, sample["target_token_ids"], strict=False
):
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
@@ -299,7 +299,7 @@ class KDStrategyLoader(StrategyLoader):
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
def _get_strategy_cls(self, cfg):
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
@@ -319,7 +319,7 @@ class KDStrategyLoaderV2(KDStrategyLoader):
Load KD chat template datasets with pre-tokenized logprob data
"""
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
def _get_strategy_cls(self, cfg):
return ChatTemplateStrategyWithKDv2

View File

@@ -37,7 +37,6 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
"""
# pylint: disable=duplicate-code
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
@@ -72,7 +71,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
// self.pad_to_multiple_of
) * self.pad_to_multiple_of
for f in features: # pylint: disable=invalid-name
for f in features:
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list):
f[feature_name] = (
@@ -101,7 +100,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
if has_teacher_data:
# Extract and remove from features
for f in features: # pylint: disable=invalid-name
for f in features:
target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids"))
target_mask_list.append(f.pop("target_mask"))
@@ -117,24 +116,25 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
padded_teacher_mask_list = []
for t_logprobs, t_ids, t_mask in zip(
target_logprobs_list, target_token_ids_list, target_mask_list
target_logprobs_list,
target_token_ids_list,
target_mask_list,
strict=False,
):
t_logprobs_padded = []
t_ids_padded = []
t_mask_padded = []
for lp, ids, mask in zip( # pylint: disable=invalid-name
t_logprobs, t_ids, t_mask
):
for lp, ids, mask in zip(t_logprobs, t_ids, t_mask, strict=False):
lp_len = len(lp)
if lp_len < max_k:
# Use -1e9 for padding logprobs and 0 for token_ids
pad_len = max_k - lp_len
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
lp = lp + [-1e9] * pad_len
ids = ids + [0] * pad_len
mask = mask + [0] * pad_len
else:
lp = lp[:max_k] # pylint: disable=invalid-name
lp = lp[:max_k]
ids = ids[:max_k]
mask = mask[:max_k]
@@ -216,9 +216,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
features
):
for i, sub_features in enumerate(features):
# sub_features is a list of dicts, each dict = one sequences features
# We'll merge them into out_features[i].
#
@@ -255,9 +253,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
if field_name in feat and isinstance(
feat[field_name], (list, torch.Tensor)
):
if isinstance(
feat[field_name][0], (dict, str)
): # pylint: disable=too-many-nested-blocks
if isinstance(feat[field_name][0], (dict, str)):
continue
arr = np.array(feat[field_name])
arrays.append(arr)

View File

@@ -144,7 +144,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels
api_data, batch_input_ids, labels, strict=False
):
current_target_logprobs = []
current_target_token_ids = []
@@ -165,7 +165,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
assert len(seq_input_ids) == len(input_top_logprobs)
for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
range(len(seq_input_ids)), seq_input_ids, seq_labels, strict=False
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
@@ -202,7 +202,8 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data)
list(row)
for row in zip(*pos_top_logprobs_data, strict=False)
]
# Ensure correct length (top_k)
@@ -317,7 +318,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
choices, batch_input_ids, labels, strict=False
):
# seq_input_ids: List[int]
# seq_labels: List[int]
@@ -342,7 +343,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
seq_len = len(seq_input_ids)
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
for i, _, label in zip(
range(seq_len), seq_input_ids, seq_labels, strict=False
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
@@ -424,7 +427,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(max(0, seq_len - len(current_target_logprobs))):
for _ in range(max(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)

View File

@@ -197,7 +197,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
CHUNK_SIZE = chunk_size
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = []
grad_bias_acc = (
@@ -298,8 +298,8 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape # pylint: disable=invalid-name
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
B, N, D = student_input.shape
K = target_token_ids.shape[-1]
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])

View File

@@ -40,10 +40,9 @@ def kldiv_forward_llama_like(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None

View File

@@ -15,6 +15,7 @@
"""
loss for top_k KL divergence
"""
import torch
from torch import nn
@@ -117,7 +118,6 @@ class ChunkedTopKKDLoss(nn.Module):
target_mask: torch.Tensor, # [B, seq_len, K]
num_items_in_batch: int = -1, # optional batch size for normalization
) -> torch.Tensor:
# 1. Split along the "token" dimension (dim=1).
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
@@ -131,7 +131,11 @@ class ChunkedTopKKDLoss(nn.Module):
# 2. Loop over each chunk and compute a chunk-specific loss.
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
student_logits_chunks,
token_ids_chunks,
logprobs_chunks,
mask_chunks,
strict=False,
):
# We pass num_items_in_batch=-1 so that the kd_loss
# will average over *this chunk's* valid tokens only.

View File

@@ -21,7 +21,6 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
# pylint: disable=too-many-ancestors
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)

View File

@@ -18,6 +18,7 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
from .args import LigerArgs
from .plugin import LigerPlugin

View File

@@ -41,7 +41,6 @@ def lce_forward(
This is useful when using packed tensor format (single dimension for batch and sequence length).
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
@@ -181,7 +180,7 @@ def patch_lce_forward(
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = lce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "

View File

@@ -2,8 +2,6 @@
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union
import torch

View File

@@ -2,8 +2,6 @@
Jamba model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union
import torch

View File

@@ -46,7 +46,6 @@ def lce_forward(
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
@@ -78,9 +77,7 @@ def lce_forward(
hidden_states = outputs[0]
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
raise Exception( # pylint: disable=broad-exception-raised
"Liger Kernel does not support pretraining_tp!!"
)
raise Exception("Liger Kernel does not support pretraining_tp!!")
logits = None
loss = None
@@ -128,7 +125,7 @@ def apply_liger_kernel_to_llama4(
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -144,15 +141,15 @@ def apply_liger_kernel_to_llama4(
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.llama4.modeling_llama4 # noqa: F401 # pylint: disable=unused-import
import transformers.models.llama4.modeling_llama4 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]
@@ -165,7 +162,7 @@ def apply_liger_kernel_to_llama4(
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper

View File

@@ -43,7 +43,6 @@ def lce_forward(
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
@@ -113,9 +112,8 @@ def apply_liger_kernel_to_qwen3(
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
**kwargs,
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -130,15 +128,15 @@ def apply_liger_kernel_to_qwen3(
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
import transformers.models.qwen3.modeling_qwen3 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]

View File

@@ -45,7 +45,6 @@ def lce_forward(
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
@@ -135,9 +134,8 @@ def apply_liger_kernel_to_qwen3_moe(
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
**kwargs,
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -152,15 +150,15 @@ def apply_liger_kernel_to_qwen3_moe(
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
@@ -174,7 +172,7 @@ def apply_liger_kernel_to_qwen3_moe(
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper

View File

@@ -7,7 +7,7 @@ import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
from .args import LMEvalArgs as LMEvalArgs
class LMEvalPlugin(BasePlugin):
@@ -20,7 +20,6 @@ class LMEvalPlugin(BasePlugin):
def post_train_unload(self, cfg):
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,

View File

@@ -99,7 +99,6 @@ def lm_eval(config: str, cloud: Optional[str] = None):
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,

View File

@@ -23,7 +23,7 @@ import requests
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
from .args import SpectrumArgs as SpectrumArgs
LOG = get_logger(__name__)
@@ -46,7 +46,7 @@ def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
"^lm_head.weight$",
"^model.embed_tokens.weight$",
]
for layer_type, layer_names in top_layers_by_type.items():
for _, layer_names in top_layers_by_type.items():
for layer_name in layer_names:
unfrozen_parameters.append(layer_name)
return unfrozen_parameters
@@ -84,7 +84,7 @@ class SpectrumPlugin(BasePlugin):
snr_data = json.load(fin)
except FileNotFoundError:
pass
except Exception as exc: # pylint: disable=broad-exception-caught
except Exception as exc:
LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data:

View File

@@ -15,6 +15,7 @@
"""
Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator

View File

@@ -5,8 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch
import triton
import triton.language as tl

View File

@@ -7,8 +7,6 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable
import torch

View File

@@ -1,7 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes
import bitsandbytes as bnb

View File

@@ -99,7 +99,6 @@ def _swiglu_bwd_kernel(
tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up
# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""
SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
@@ -128,7 +127,6 @@ def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
return out
# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.loaders module"""
# pylint: disable=unused-import
# flake8: noqa
from .adapter import load_adapter, load_lora

View File

@@ -28,14 +28,12 @@ LOG = get_logger(__name__)
def setup_quantized_meta_for_peft(model: torch.nn.Module):
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
def temp_to_method(self, *args, **kwargs):
return self
for param in model.parameters():
if isinstance(param, Params4bit):
param.quant_state._orig_to = ( # pylint: disable=protected-access
param.quant_state.to
)
param.quant_state._orig_to = param.quant_state.to
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
@@ -43,10 +41,8 @@ def setup_quantized_peft_meta_for_training(model: torch.nn.Module):
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
for param in model.parameters():
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
param.quant_state.to = (
param.quant_state._orig_to # pylint: disable=protected-access
)
param.quant_state._orig_to = None # pylint: disable=protected-access
param.quant_state.to = param.quant_state._orig_to
param.quant_state._orig_to = None
def find_all_linear_names(model):

View File

@@ -102,7 +102,7 @@ class ModelLoader:
*,
inference: bool = False,
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
**kwargs,
):
"""Initializes the ModelLoader.
@@ -134,7 +134,7 @@ class ModelLoader:
# Init model config
self.model_config = load_model_config(cfg)
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
self.auto_model_loader = AutoModelForCausalLM
# Initialize the patch manager
self.patch_manager = PatchManager(
@@ -607,27 +607,19 @@ class ModelLoader:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
self.model_config._attn_implementation = "flex_attention"
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
self.model_config._attn_implementation = "sdpa"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
self.model_config._attn_implementation = "eager"
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
@@ -767,7 +759,7 @@ class ModelLoader:
)
elif self.model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
MambaLMHeadModel = fix_mamba_attn_for_loss()
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
self.model_kwargs["device"] = torch.cuda.current_device()
@@ -816,7 +808,6 @@ class ModelLoader:
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
# pylint: disable=protected-access
if self.cfg.tensor_parallel_size > 1:
# workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
# TODO(wing): remove once 4.54.1 is released

View File

@@ -277,6 +277,14 @@ class PatchManager:
has_remote_code=has_remote_code,
)
if self.cfg.sample_packing:
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
apply_multipack_dataloader_patch,
)
LOG.info("Applying multipack dataloader patch for sample packing...")
apply_multipack_dataloader_patch()
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
if (

Some files were not shown because too many files have changed in this diff Show More