Compare commits


12 Commits

Author SHA1 Message Date
Dan Saunders
d3bea3a2eb broken 2025-08-25 16:51:36 +00:00
Dan Saunders
2e2302aae3 remove unused 2025-08-25 15:46:25 +00:00
Dan Saunders
3a35076513 seems to be working? 2025-08-25 14:22:32 +00:00
Dan Saunders
79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00
Dan Saunders
eea7a006e1 make multipack sampler patch explicit (#3096)
* make multipack sampler patch explicit

* combining
2025-08-22 14:29:10 -04:00
Wing Lian
ab4d604a8f upgrade peft for 0.17.1 (#3094)
* upgrade peft to 0.17.1

* upgrade for transformers too
2025-08-22 07:26:30 -04:00
Wing Lian
0fa752e58b upgrade flash-attn to 2.8.3 for gpt-oss attn sink support (#3082) 2025-08-21 15:04:10 -04:00
Dan Saunders
08e517ea48 Update .coderabbit.yaml (#3091) [skip ci] 2025-08-20 22:14:13 -04:00
Wing Lian
07fd22f39b better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090) 2025-08-20 15:17:48 -04:00
Wing Lian
06eaf6c448 misc fixes (#3085) 2025-08-20 08:52:26 -04:00
goggle
050210e637 fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files

* fix typo

* cli: harden sweep output_dir handling with base fallback

- Ensure sweep permutations always resolve a valid output_dir
- Default to ./model-out if neither permutation nor base config sets output_dir
- Append sweepXXXX suffix consistently for each permutation
- Prevent Path(None) TypeError and improve robustness of sweep config generation

* fix typo

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-19 20:25:20 -04:00
Wing Lian
05cedbfb1e add baseten info for gpt-oss recipe (#3078)
* add baseten info for gpt-oss recipe

* incorporate PR review
2025-08-19 13:30:37 -04:00
310 changed files with 11477 additions and 13349 deletions

View File

@@ -1,3 +1,3 @@
 [bandit]
 exclude = tests
-skips = B101,B615
+skips = B101,B615,B102,B110

View File

@@ -12,5 +12,6 @@ reviews:
 auto_review:
 enabled: true
 drafts: false
+auto_incremental_review: true
 chat:
 auto_reply: true

View File

@@ -1,5 +0,0 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -1,4 +0,0 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,22 +10,12 @@ repos:
 - id: trailing-whitespace
 - id: no-commit-to-branch
 args: ['--branch', 'main']
-- repo: https://github.com/psf/black
-rev: 25.1.0
-hooks:
-- id: black
-- repo: https://github.com/pycqa/isort
-rev: 6.0.1
-hooks:
-- id: isort
-- repo: https://github.com/PyCQA/flake8
-rev: 7.3.0
-hooks:
-- id: flake8
-- repo: https://github.com/pylint-dev/pylint
-rev: v3.3.8
-hooks:
-- id: pylint
+- repo: https://github.com/astral-sh/ruff-pre-commit
+rev: v0.12.9
+hooks:
+- id: ruff
+args: [--fix]
+- id: ruff-format
 - repo: https://github.com/pre-commit/mirrors-mypy
 rev: v1.17.1
 hooks:
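With the black/isort/flake8/pylint hooks collapsed into ruff, the whole chain can still be exercised locally before pushing. A minimal sketch using standard pre-commit commands (nothing here is specific to this repo):

```bash
pip install pre-commit
pre-commit install          # register the git hook once per clone
pre-commit run --all-files  # runs the ruff, ruff-format, and mypy hooks across the repo
```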

View File

@@ -1,15 +0,0 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -2,8 +2,6 @@
 modal application to run axolotl gpu tests in Modal
 """
-# pylint: disable=duplicate-code
 import os
 import pathlib
 import tempfile
@@ -63,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
 # Propagate errors from subprocess.
 if exit_code := subprocess.call(cmd.split(), cwd=run_folder):  # nosec
-exit(exit_code)  # pylint: disable=consider-using-sys-exit
+exit(exit_code)
 @app.function(

View File

@@ -1,7 +1,5 @@
"""Modal app to run axolotl GPU tests""" """Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os import os
import pathlib import pathlib
import tempfile import tempfile
@@ -70,4 +68,4 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess. # Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit exit(exit_code)

View File

@@ -47,7 +47,6 @@ class QuartoGenerator:
"""Check if a type is a Pydantic BaseModel.""" """Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel) return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any: def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations.""" """Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+) # Handle Annotated types (Python 3.9+)
@@ -124,7 +123,6 @@ class QuartoGenerator:
return field_type return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type( def _extract_all_pydantic_models_from_type(
self, field_type self, field_type
) -> list[type[BaseModel]]: ) -> list[type[BaseModel]]:
@@ -318,7 +316,6 @@ class QuartoGenerator:
return all_groups return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source( def _extract_field_groups_from_source(
self, model_class: type[BaseModel] self, model_class: type[BaseModel]
) -> list[dict]: ) -> list[dict]:
@@ -503,7 +500,7 @@ class QuartoGenerator:
nested_schema = nested_model.model_json_schema() nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {}) nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", []) nested_required = nested_schema.get("required", [])
except Exception: # pylint: disable=broad-exception-caught except Exception:
# Fallback: use model fields directly # Fallback: use model fields directly
nested_properties = {} nested_properties = {}
nested_required = [] nested_required = []
@@ -607,7 +604,7 @@ class QuartoGenerator:
schema = model_class.model_json_schema() schema = model_class.model_json_schema()
properties = schema.get("properties", {}) properties = schema.get("properties", {})
required = schema.get("required", []) required = schema.get("required", [])
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e:
print( print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead." f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
) )

File diff suppressed because it is too large

View File

@@ -41,6 +41,12 @@ model, and final model output, you may need at least 3TB of free disk space to k
 axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
 ```
+To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
+training of the 120B model using Baseten Truss. You can read more about this recipe on
+[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
+be found on their
+[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
 ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
 See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
@@ -61,9 +67,23 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
 ### Inferencing your fine-tuned model
+#### vLLM
 GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
 for more information about using a special vllm-openai docker image for inferencing with vLLM.
+Optionally, vLLM can be installed from nightly:
+```bash
+pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
+```
+and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
+```bash
+vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
+```
+#### SGLang
 SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
 SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
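For the ERRATA noted above, the FSDP-prefixed architecture name in `config.json` can be stripped by hand or scripted. A minimal sketch, assuming `jq` is available and the saved entry looks like `FSDPGptOssForCausalLM` (both assumptions, check your own `config.json` first):

```bash
cfg=./outputs/gpt-oss-out/config.json
# drop a leading "FSDP" from every entry in the "architectures" list
jq '.architectures |= map(sub("^FSDP"; ""))' "$cfg" > "$cfg.tmp" && mv "$cfg.tmp" "$cfg"
```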

View File

@@ -44,7 +44,7 @@ bf16: true
 tf32: true
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
 tf32: true
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
 field_thinking: thinking
 template_thinking_key: thinking
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path: ./outputs/last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
 tf32: true
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
 field_thinking: thinking
 template_thinking_key: thinking
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path: ./outputs/last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
 tf32: true
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
 tf32: true
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -1,57 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
plugins:
- diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 10
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
save_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,58 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.05
plugins:
- diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps
eval_strategy: steps
save_steps: 500
eval_steps: 500
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -26,3 +26,34 @@ include-package-data = true
 [tool.setuptools.cmdclass]
 build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
+[tool.ruff]
+line-length = 88
+target-version = "py310"
+[tool.ruff.lint]
+select = ["E", "F", "W", "C90", "B"]
+ignore = [
+"E203", # Whitespace before ':'
+"E501", # Line too long
+"C901", # Too complex
+"B019", # Use of functools.cache on methods
+"E722", # Bare except
+"F821", # Undefined name (for dynamic exec)
+]
+[tool.ruff.lint.isort]
+known-third-party = ["wandb", "comet_ml"]
+known-local-folder = ["src", "tests"]
+# Black-compatible isort settings
+force-single-line = false
+combine-as-imports = true
+split-on-trailing-comma = true
+[tool.ruff.format]
+# Use black's formatting style exactly
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = false
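The `[tool.ruff]` settings above are written to mirror the previous black/isort/flake8 behavior, so day-to-day usage reduces to the two standard ruff commands, for example:

```bash
ruff check --fix .   # lint, covering the checks the removed flake8/pylint configs handled
ruff format .        # format in place of black
```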

View File

@@ -13,8 +13,8 @@ liger-kernel==0.6.1
 packaging==23.2
 huggingface_hub>=0.33.0
-peft==0.17.0
-transformers==4.55.2
+peft>=0.17.0
+transformers==4.55.3
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0

View File

@@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"):
 break
 if not field_messages:
 raise ValueError(
-f'No conversation field found in dataset: {", ".join(feature_keys)}'
+f"No conversation field found in dataset: {', '.join(feature_keys)}"
 )
 ds_cfg["field_messages"] = field_messages
@@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"):
 break
 if not message_property_mappings["role"]:
 raise ValueError(
-f'No role field found in messages: {", ".join(message_fields)}'
+f"No role field found in messages: {', '.join(message_fields)}"
 )
 for key in ["content", "text", "value"]:
@@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"):
 break
 if not message_property_mappings["content"]:
 raise ValueError(
-f'No content field found in messages: {", ".join(message_fields)}'
+f"No content field found in messages: {', '.join(message_fields)}"
 )
 ds_cfg["message_property_mappings"] = message_property_mappings

View File

@@ -1,11 +1,10 @@
 # noqa
-# pylint: skip-file
 import sys
 try:
 import torch
-except ImportError:
-raise ImportError("Install torch via `pip install torch`")
+except ImportError as error:
+raise ImportError("Install torch via `pip install torch`") from error
 from packaging.version import Version as V
 use_uv = "--uv" in sys.argv[1:]

View File

@@ -118,9 +118,9 @@ def get_package_version():
 extras_require = {
-"flash-attn": ["flash-attn==2.8.2"],
+"flash-attn": ["flash-attn==2.8.3"],
 "ring-flash-attn": [
-"flash-attn==2.8.2",
+"flash-attn==2.8.3",
 "ring-flash-attn>=0.1.7",
 "yunchang==0.6.0",
 ],

View File

@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
 def print_axolotl_text_art():
 """Prints axolotl ASCII art."""
-global HAS_PRINTED_LOGO  # pylint: disable=global-statement
+global HAS_PRINTED_LOGO
 if HAS_PRINTED_LOGO:
 return
 if is_main_process():

View File

@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
 if exit_code := subprocess.call(  # nosec B603
 cmd.split(), cwd=run_folder, env=new_env
 ):
-exit(exit_code)  # pylint: disable=consider-using-sys-exit
+exit(exit_code)
 # Commit writes to volume.
 if volumes:
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
 return res
 def get_image(self):
-docker_tag = "main-py3.11-cu124-2.6.0"
+docker_tag = "main-py3.11-cu126-2.7.1"
 if self.config.docker_tag:
 docker_tag = self.config.docker_tag
 docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -130,7 +130,6 @@ class ModalCloud(Cloud):
 res = []
 if self.config.secrets:
 for key in self.config.get("secrets", []):
-# pylint: disable=duplicate-code
 if isinstance(key, str):
 if val := os.environ.get(key, ""):
 res.append(modal.Secret.from_dict({key: val}))
@@ -177,8 +176,8 @@ class ModalCloud(Cloud):
 with self.app.run(detach=True):
 modal_fn.remote(
 config_yaml,
-volumes={k: v[0] for k, v in self.volumes.items()},
 *args,
+volumes={k: v[0] for k, v in self.volumes.items()},
 **kwargs,
 )
@@ -187,7 +186,7 @@ class ModalCloud(Cloud):
 return int(self.config.timeout)
 return 60 * 60 * 24  # 24 hours
-def get_train_gpu(self):  # pylint: disable=too-many-return-statements
+def get_train_gpu(self):
 count = self.config.gpu_count or 1
 family = self.config.gpu.lower() or "l40s"
@@ -200,7 +199,7 @@
 if family in ["a10", "a10g"]:
 return modal.gpu.A10G(count=count)
 if family == "h100":
-return modal.gpu.H100(count=count)
+return f"H100:{count}"
 if family == "t4":
 return modal.gpu.T4(count=count)
 if family == "l4":
@@ -277,7 +276,7 @@ def _train(
 launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
 launcher_args: list[str] | None = None,
 volumes=None,
-**kwargs,  # pylint: disable=unused-argument
+**kwargs,
 ):
 Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
 with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:

View File

@@ -210,7 +210,7 @@ def load_cfg(
 try:
 device_props = torch.cuda.get_device_properties("cuda")
 gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
-except:  # pylint: disable=bare-except  # noqa: E722
+except:
 gpu_version = None
 prepare_plugins(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)

View File

@@ -28,7 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
 cfg: Dictionary mapping `axolotl` config keys to values.
 cli_args: CLI arguments.
 """
-# pylint: disable=duplicate-code
 check_accelerate_default_config()
 if int(os.getenv("LOCAL_RANK", "0")) == 0:
 check_user_token()
@@ -49,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
 config: Path to `axolotl` config YAML file.
 kwargs: Additional keyword arguments to override config file values.
 """
-# pylint: disable=duplicate-code
 parsed_cfg = load_cfg(config, **kwargs)
 parser = HfArgumentParser(TrainerCliArgs)
 parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -35,7 +35,7 @@ def get_multi_line_input() -> str:
instruction = "" instruction = ""
for line in sys.stdin: for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join instruction += line
return instruction return instruction
@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template) chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template": elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config( chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
@@ -167,7 +167,6 @@ def do_inference_gradio(
if not instruction: if not instruction:
return return
if prompter_module: if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next( prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n")) prompter_module().build_prompt(instruction=instruction.strip("\n"))
) )
@@ -252,7 +251,7 @@ def do_cli(
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs) parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs) parser = transformers.HfArgumentParser(InferenceCliArgs)

View File

@@ -1,7 +1,5 @@
"""Click CLI definitions for various axolotl commands.""" """Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import os import os
import subprocess # nosec B404 import subprocess # nosec B404
from typing import Literal, Optional from typing import Literal, Optional

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
 class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
 """A custom planner to cast tensors to bfloat16 on the fly during loading."""
-def commit_tensor(self, read_item, tensor):  # pylint: disable=unused-argument
+def commit_tensor(self, read_item, tensor):
 tensor.copy_(tensor.to(torch.bfloat16))
@@ -59,10 +59,10 @@ def _distributed_checkpoint_to_merged_weights(
 state_dict: Dict = {}
 save_path_ = Path(save_path)
 save_path_.mkdir(exist_ok=True)
-dist_cp_format_utils._load_state_dict(  # pylint: disable=protected-access
+dist_cp_format_utils._load_state_dict(
 state_dict,
 storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
-planner=BFloat16CastPlanner(),  # pylint: disable=protected-access
+planner=BFloat16CastPlanner(),
 no_dist=True,
 )
@@ -191,7 +191,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
 config: Path to `axolotl` config YAML file.
 kwargs: Additional keyword arguments to override config file values.
 """
-# pylint: disable=duplicate-code
 parsed_cfg = load_cfg(config, **kwargs)
 fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -73,7 +73,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
 AutoModelForCausalLM.from_pretrained(
 model_name, trust_remote_code=True
 )
-except Exception as exc:  # pylint: disable=broad-exception-caught,unused-variable  # nosec B110  # noqa F841
+except Exception:  # nosec B110
 pass
 # fmt: on
@@ -95,9 +95,10 @@
 config: Path to `axolotl` config YAML file.
 kwargs: Additional keyword arguments to override config file values.
 """
-# pylint: disable=duplicate-code
 os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
-parsed_cfg = load_cfg(config, **kwargs)
+is_preprocess = kwargs.pop("is_preprocess", True)
+parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
 parsed_cfg.is_preprocess = True
 parser = transformers.HfArgumentParser(PreprocessCliArgs)
 parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -59,7 +59,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
 config: Path to `axolotl` config YAML file.
 kwargs: Additional keyword arguments to override config file values.
 """
-# pylint: disable=duplicate-code
 parsed_cfg = load_cfg(config, **kwargs)
 parser = HfArgumentParser(TrainerCliArgs)
 parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -65,7 +65,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
 for field in reversed(dataclasses.fields(config_class)):
 field_type = _strip_optional_type(field.type)
-if field_type == bool:
+if field_type is bool:
 field_name = field.name.replace("_", "-")
 option_name = f"--{field_name}/--no-{field_name}"
 function = click.option(
@@ -103,7 +103,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
 for name, field in reversed(config_class.model_fields.items()):
 field_type = _strip_optional_type(field.annotation)
-if field_type == bool:
+if field_type is bool:
 field_name = name.replace("_", "-")
 option_name = f"--{field_name}/--no-{field_name}"
 function = click.option(

View File

@@ -3,11 +3,12 @@
 import random
 from copy import deepcopy
 from itertools import product
+from typing import Any
 def generate_sweep_configs(
 base_config: dict[str, list], sweeps_config: dict[str, list]
-) -> list[dict[str, list]]:
+) -> list[dict[str, Any]]:
 """
 Recursively generates all possible configurations by applying sweeps to the base config.
@@ -48,7 +49,10 @@ def generate_sweep_configs(
 new_config = {}
 # new_config = deepcopy(base_config)
 # Combine regular parameters with paired parameters
-full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
+full_combo = {
+**dict(zip(param_names, reg_combo, strict=False)),
+**paired_set,
+}
 for param_name, param_value in full_combo.items():
 new_config[param_name] = param_value
 print(new_config)
@@ -57,7 +61,7 @@
 # If no paired values, just use regular combinations
 # new_config = deepcopy(base_config)
 new_config = {}
-for param_name, param_value in zip(param_names, reg_combo):
+for param_name, param_value in zip(param_names, reg_combo, strict=False):
 new_config[param_name] = param_value
 print(new_config)
 all_combinations.append(new_config)

View File

@@ -4,6 +4,7 @@ import os
 import subprocess  # nosec
 import sys
 import tempfile
+from pathlib import Path
 from typing import Any, Iterator, Literal
 import yaml
@@ -88,8 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
 # Generate all possible configurations
 permutations = generate_sweep_configs(base_config, sweep_config)
 is_group = len(permutations) > 1
-for permutation in permutations:
-# pylint: disable=consider-using-with
+base_output_dir = base_config.get("output_dir", "./model-out")
+for idx, permutation in enumerate(permutations, start=1):
+permutation_dir = Path(permutation.get("output_dir", base_output_dir))
+permutation_id = f"sweep{idx:04d}"
+permutation["output_dir"] = str(permutation_dir / permutation_id)
 temp_file = tempfile.NamedTemporaryFile(
 mode="w",
 suffix=".yaml",

View File

@@ -39,7 +39,7 @@ def do_vllm_serve(
 model = cfg.base_model
 serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
-vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
+vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
 tensor_parallel_size = 1
 data_parallel_size = 1
@@ -68,7 +68,6 @@
 cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
 )
-# pylint: disable=unexpected-keyword-arg
 vllm_script_args = AxolotlScriptArguments(
 model=model,
 tensor_parallel_size=tensor_parallel_size,

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
 from datasets import Dataset
-import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
+import axolotl.monkeypatch.data.batch_dataset_fetcher  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.data import prepare_datasets, prepare_preference_datasets

View File

@@ -67,9 +67,7 @@ class JsonToJsonlConverter:
 self.json_parser = json_parser
 self.jsonl_serializer = jsonl_serializer
-def convert(
-self, input_file_path, output_file_path
-):  # pylint: disable=unused-argument
+def convert(self, input_file_path, output_file_path):
 content = self.file_reader.read(input_file_path)
 data = self.json_parser.parse(content)
 # data = [r for r in data if r["conversations"]]  # vicuna cleaned has rows with empty conversations

View File

@@ -84,9 +84,7 @@ def create_causal_mask(
 batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
 if attention_mask is not None:
-def causal_doc_mask_mod(
-batch_idx, head_idx, q_idx, kv_idx
-):  # pylint: disable=unused-argument
+def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 """
 Defines the logic of a block causal mask by combining both a standard causal mask
 and a block diagonal document mask.
@@ -103,9 +101,7 @@
 mask_factory_function = causal_doc_mask_mod
 else:
 mask_factory_function = causal_mask_function
-mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
-config._attn_implementation  # pylint: disable=protected-access
-]
+mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
 # Do not allow skip if we are compiling (this is to match BC)
 allow_is_causal_skip = (

View File

@@ -44,7 +44,7 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 LOG = logging.getLogger(__name__)
 with suppress(ImportError):
-import torch._dynamo  # pylint: disable=ungrouped-imports
+import torch._dynamo
 class TrainerBuilderBase(abc.ABC):
@@ -260,14 +260,14 @@
 adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
 if self.cfg.optimizer == "muon":
-from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
+from axolotl.contribs.mit.muon import (
 MuonOptimizerFactory,
 )
 optimizer_cls = MuonOptimizerFactory
 optimizer_kwargs.update(adam_kwargs)
 elif self.cfg.optimizer == "dion":
-from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
+from axolotl.contribs.mit.dion import (
 DionOptimizerFactory,
 )
@@ -414,12 +414,8 @@
 def _configure_torch_compile(self, training_args_kwargs: dict):
 if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
-torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
-True
-)
-torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
-256
-)
+torch._dynamo.config.suppress_errors = True
+torch._dynamo.config.accumulated_cache_size_limit = 256
 training_args_kwargs["torch_compile"] = self.cfg.torch_compile
 if self.cfg.torch_compile_backend:
 training_args_kwargs["torch_compile_backend"] = (

View File

@@ -10,7 +10,6 @@ import transformers
 from transformers import (
 DataCollatorWithFlattening,
 EarlyStoppingCallback,
-Trainer,
 )
 from trl.trainer.utils import RewardDataCollatorWithPadding
@@ -345,16 +344,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 training_args_cls = AxolotlPRMConfig
 else:
 training_args_cls = AxolotlTrainingArguments
-training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+training_args = training_args_cls(
 **training_arguments_kwargs,
 )
 training_args = self.hook_post_create_training_args(training_args)
 # unset run_name so wandb sets up experiment names
 if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
-training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
-None
-)
+training_args.run_name = None
 data_collator_kwargs = {
 "padding": True,  # True/"longest" is the default
@@ -386,11 +383,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 **data_collator_kwargs,
 )
 sig = inspect.signature(trainer_cls)
-if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
+if "processing_class" in sig.parameters:
 trainer_kwargs["processing_class"] = self.tokenizer
 elif "tokenizer" in sig.parameters:
 trainer_kwargs["tokenizer"] = self.tokenizer
 if (
 trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
 and self.cfg.datasets is not None

View File

@@ -168,16 +168,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 if plugin_training_args:
 training_args_kwargs.update(plugin_training_args)
-training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+training_args = training_args_cls(
 logging_first_step=True,
 **training_args_kwargs,
 )
 # unset run_name so wandb sets up experiment names
 if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
-training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
-None
-)
+training_args.run_name = None
 return training_args, trainer_kwargs

View File

@@ -10,7 +10,7 @@ from .shared import wrap_tools
 def format_message(
 message: Messages,
-message_index: Optional[int] = None,  # pylint: disable=unused-argument
+message_index: Optional[int] = None,
 ) -> Messages:
 if message.is_chat_formatted:
 return message

View File

@@ -15,11 +15,11 @@ class MessageRoles(str, Enum):
 Message roles for the system, user, assistant, and tools
 """
-system = "system"  # pylint: disable=invalid-name
-user = "user"  # pylint: disable=invalid-name
-assistant = "assistant"  # pylint: disable=invalid-name
-tool = "tool"  # pylint: disable=invalid-name
-ipython = (  # pylint: disable=invalid-name
+system = "system"
+user = "user"
+assistant = "assistant"
+tool = "tool"
+ipython = (
 # for responses from builtin tools
 "ipython"
 )
@@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum):
 Message content types for text, image, audio, tool calls, and tool responses
 """
-special_token = "special_token"  # pylint: disable=invalid-name  # nosec B105
-text = "text"  # pylint: disable=invalid-name
-image = "image"  # pylint: disable=invalid-name
-audio = "audio"  # pylint: disable=invalid-name
-tool_call = "tool_call"  # pylint: disable=invalid-name  # to differentiate regular responses from tool calls from the assistant
-tool_response = "tool_response"  # pylint: disable=invalid-name
+special_token = "special_token"  # nosec B105
+text = "text"
+image = "image"
+audio = "audio"
+tool_call = "tool_call"
+tool_response = "tool_response"
 class SpecialToken(str, Enum):
@@ -43,8 +43,8 @@
 Special tokens for beginning of string and end of string
 """
-bos_token = "bos_token"  # pylint: disable=invalid-name  # nosec B105
-eos_token = "eos_token"  # pylint: disable=invalid-name  # nosec B105
+bos_token = "bos_token"  # nosec B105
+eos_token = "eos_token"  # nosec B105
 class ToolCallFunction(BaseModel):
@@ -73,7 +73,7 @@ class ToolCallContents(BaseModel):
 name: str
 arguments: dict[str, Union[str, int]]
-id: Optional[str] = None  # pylint: disable=invalid-name
+id: Optional[str] = None
 def __str__(self) -> str:
 data = {"name": self.name, "arguments": self.arguments}
@@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel):
 name: str
 content: Union[str, dict[str, Union[str, int, float]]]
-id: Optional[str] = None  # pylint: disable=invalid-name
+id: Optional[str] = None
 def __str__(self) -> str:
 data = {"name": self.name, "content": self.content}

View File

@@ -1,23 +1,17 @@
""" """
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. This module contains a function that builds a transform that takes a row from the
dataset and converts it to a Chat.
""" """
from typing import Any, Mapping, Union from typing import Any, Mapping
def chat_message_transform_builder( # pylint: disable=dangerous-default-value def chat_message_transform_builder(
train_on_inputs=False, train_on_inputs=False,
conversations_field: str = "conversations", conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role" message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: Union[str, list[str]] = [ message_field_content: str | list[str] | None = None, # commonly "content"
"value", message_field_training: str | list[str] | None = None, # commonly "weight"
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
): ):
"""Builds a transform that takes a row from the dataset and converts it to a Chat """Builds a transform that takes a row from the dataset and converts it to a Chat
@@ -39,6 +33,12 @@ def chat_message_transform_builder( # pylint: disable=dangerous-default-value
A function that takes a list of conversations and returns a list of messages. A function that takes a list of conversations and returns a list of messages.
""" """
if message_field_training is None:
message_field_training = ["train", "weight"]
if message_field_content is None:
message_field_content = ["value", "text", "content"]
if message_field_role is None:
message_field_role = ["role", "from"]
message_field_role = ( message_field_role = (
[message_field_role] [message_field_role]
if isinstance(message_field_role, str) if isinstance(message_field_role, str)

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.core.trainers""" """Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .base import AxolotlTrainer from .base import AxolotlTrainer

View File

@@ -1,7 +1,5 @@
"""Module for customized trainers""" """Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations from __future__ import annotations
import os import os
@@ -82,9 +80,7 @@ class AxolotlTrainer(
super().__init__(*_args, **kwargs) super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict( self._stored_metrics = defaultdict(lambda: defaultdict(list))
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -287,9 +283,9 @@ class AxolotlTrainer(
# fmt: off # fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers: if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"): if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition self._eval_dataloaders[dataloader_key] = dataloader # type: ignore
else: else:
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init self._eval_dataloaders = {dataloader_key: dataloader}
# fmt: on # fmt: on
return self.accelerator.prepare(dataloader) return self.accelerator.prepare(dataloader)
@@ -445,7 +441,7 @@ class AxolotlTrainer(
model, model,
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument num_items_in_batch=None,
): ):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs, inputs,
@@ -526,9 +522,7 @@ class AxolotlTrainer(
accelerator_config = self.args.accelerator_config.to_dict() accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False) use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state: if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access AcceleratorState._reset_state(reset_partial_state=True)
reset_partial_state=True
)
super().create_accelerator_and_postprocess() super().create_accelerator_and_postprocess()
@@ -542,7 +536,6 @@ class AxolotlTrainer(
): ):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True self.accelerator.state.fsdp_plugin.limit_all_gathers = True
# pylint: disable=unused-argument
def additional_accelerator_args( def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -575,26 +568,9 @@ class AxolotlTrainer(
""" """
# logs either has 'loss' or 'eval_loss' # logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
# Add reduced stored metrics to logs for key, metrics in self._stored_metrics[train_eval].items():
for key, metric_data in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item()
values = torch.tensor(metric_data["values"])
reduction_type = metric_data["reduction"]
if reduction_type == "mean":
logs[key] = values.mean().item()
elif reduction_type == "min":
logs[key] = values.min().item()
elif reduction_type == "max":
logs[key] = values.max().item()
elif reduction_type == "sum":
logs[key] = values.sum().item()
else:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(logs[key], 4)
if is_main_process(): if is_main_process():
# Add memory usage # Add memory usage
@@ -611,27 +587,10 @@ class AxolotlTrainer(
return super().log(logs, start_time) return super().log(logs, start_time)
def store_metrics( def store_metrics(
self, self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
) -> None: ) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
"""
for key, value in metrics.items(): for key, value in metrics.items():
if isinstance(value, tuple): self._stored_metrics[train_eval][key].append(value)
metric_value, metric_reduction = value
else:
metric_value, metric_reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(metric_value)
self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
def _save_checkpoint(self, model, trial, **kwargs): def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey # make sure the checkpoint dir exists, since trainer is flakey

View File

@@ -101,11 +101,11 @@ class AxolotlDPOTrainer(
 ) -> dict[str, torch.Tensor]:
 if self.args.dpo_norm_loss:
 # fmt: off
-loss_type: str = self.loss_type  # type: ignore[has-type]  # pylint: disable=access-member-before-definition
+loss_type: str = self.loss_type  # type: ignore[has-type]
 # fmt: on
 # concatenated_forward handles avg token logprob for ipo case already
-self.loss_type = "ipo"  # pylint: disable=attribute-defined-outside-init
+self.loss_type = "ipo"
 res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
-self.loss_type = loss_type  # pylint: disable=attribute-defined-outside-init
+self.loss_type = loss_type
 return res
 return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -128,9 +128,7 @@ class GRPOStrategy:
 return grpo_args_kwargs
 @classmethod
-def set_trainer_args(
-cls, cfg: DictDefault
-) -> list[Any]:  # pylint: disable=unused-argument
+def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
 trainer_args = []
 if cfg.trl and cfg.trl.reward_funcs:
 reward_funcs = []
@@ -151,7 +149,7 @@
 return trainer_kwargs
 @classmethod
-def get_collator(cls, *args, **kwargs):  # pylint: disable=unused-argument
+def get_collator(cls, *args, **kwargs):
 # No data collation is needed in GRPO, handled by trl's trainer __init__
 return None

View File

@@ -1,7 +1,5 @@
"""Axolotl GRPO trainers (with and without sequence parallelism handling)""" """Axolotl GRPO trainers (with and without sequence parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings import warnings
from functools import partial from functools import partial
from typing import Any from typing import Any
@@ -52,7 +50,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.monkeypatch.ring_attn import get_ring_attn_group from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available(): if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig from peft import PeftConfig
@@ -253,7 +250,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training""" """Get dataloader for training"""
train_dataset = self.train_dataset train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing # Handle dataset preprocessing
@@ -266,7 +263,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
train_dataset, description="training" train_dataset, description="training"
) )
else: else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init self.data_collator = self._get_collator_with_removed_columns(
data_collator, data_collator,
description="training", description="training",
) )
@@ -308,10 +305,10 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using either vLLM or regular generation # Generate completions using either vLLM or regular generation
if self.args.use_vllm: if self.args.use_vllm:
# First, have main process load weights if needed # First, have main process load weights if needed
# pylint: disable=access-member-before-definition
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type] if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
self._move_model_to_vllm() self._move_model_to_vllm()
# pylint: disable=attribute-defined-outside-init
self._last_loaded_step = self.state.global_step self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
@@ -333,8 +330,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Extract prompts from this SP group, accounting for num_generations duplicates # Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group # We only need prompts from one rank in each SP group
group_prompts = all_prompts_text[ group_prompts = all_prompts_text[
group_leader_rank group_leader_rank * len(prompts_text) : (
* len(prompts_text) : (group_leader_rank + 1) group_leader_rank + 1
)
* len(prompts_text) : self.num_generations * len(prompts_text) : self.num_generations
] ]
@@ -485,7 +483,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
) )
if is_conversational(inputs[0]): if is_conversational(inputs[0]):
completions = [] completions = []
for prompt, completion in zip(prompts, completions_text): for prompt, completion in zip(prompts, completions_text, strict=False):
bootstrap = ( bootstrap = (
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
) )
@@ -503,6 +501,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.reward_funcs, self.reward_funcs,
self.reward_processing_classes, self.reward_processing_classes,
self.reward_func_names, self.reward_func_names,
strict=False,
) )
): ):
with profiling_context(self, reward_func_name): with profiling_context(self, reward_func_name):
@@ -511,14 +510,17 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
): # Module instead of PretrainedModel for compat with compiled models ): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]): if is_conversational(inputs[0]):
messages = [ messages = [
{"messages": p + c} for p, c in zip(prompts, completions) {"messages": p + c}
for p, c in zip(prompts, completions, strict=False)
] ]
texts = [ texts = [
apply_chat_template(x, reward_processing_class)["text"] apply_chat_template(x, reward_processing_class)["text"]
for x in messages for x in messages
] ]
else: else:
texts = [p + c for p, c in zip(prompts, completions)] texts = [
p + c for p, c in zip(prompts, completions, strict=False)
]
reward_inputs = reward_processing_class( reward_inputs = reward_processing_class(
text=texts, text=texts,
return_tensors="pt", return_tensors="pt",
@@ -564,7 +566,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
row_reward_kwargs["completion"] = completions[nan_row_idx] row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn( warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. " f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward." "Please ensure that at least one reward function returns a valid reward.",
stacklevel=2,
) )
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the

View File

@@ -5,7 +5,6 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation""" """Mamba specific trainer to handle loss calculation"""
@@ -15,8 +14,8 @@ class AxolotlMambaTrainer(AxolotlTrainer):
self, self,
model, model,
inputs, inputs,
return_outputs=False, # pylint: disable=unused-argument return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument num_items_in_batch=None,
): ):
input_ids = inputs.pop("input_ids") input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits lm_logits = model(input_ids).logits

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.core.trainers.mixins""" """Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin from .activation_checkpointing import ActivationOffloadingMixin

View File

@@ -92,7 +92,7 @@ def get_lora_act_offloading_ctx_manager(
`contextlib.ContextDecorator`: `contextlib.ContextDecorator`:
Activation offloading context manager for the model. Activation offloading context manager for the model.
""" """
# pylint: disable=unnecessary-dunder-call
activations_handling_ctx = OffloadActivations( activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory, use_pin_memory=use_pin_memory,
use_streams=use_streams, use_streams=use_streams,

View File

@@ -26,7 +26,6 @@ class DistributedParallelMixin(Trainer):
self.accelerator.distributed_type == "FSDP" self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None and self.accelerator.state.fsdp_plugin is None
): ):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP # handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU" self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU" self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"

View File

@@ -70,11 +70,11 @@ class OptimizerMixin(Trainer):
} }
) )
if params["embeddings"]: if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name lr = optimizer_kwargs["lr"]
if self.args.embedding_lr_scale: if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name lr *= self.args.embedding_lr_scale
elif self.args.embedding_lr: elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name lr = self.args.embedding_lr
optimizer_grouped_parameters.append( optimizer_grouped_parameters.append(
{ {
"params": list(params["embeddings"].values()), "params": list(params["embeddings"].values()),
@@ -143,7 +143,7 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding = getattr( loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6 self.args, "loraplus_lr_embedding", 1e-6
) )
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = create_loraplus_optimizer(
opt_model, opt_model,
optimizer_cls, optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio, loraplus_lr_ratio=loraplus_lr_ratio,
@@ -185,17 +185,15 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters() p.data_ptr(): p.numel() for p in module.parameters()
}.values() }.values()
) )
LOG.info(f"skipped {module}: {skipped/2**20}M params") LOG.info(f"skipped {module}: {skipped / 2**20}M params")
manager.register_module_override( manager.register_module_override(
module, "weight", {"optim_bits": 32} module, "weight", {"optim_bits": 32}
) )
LOG.debug(f"bitsandbytes: will optimize {module} in fp32") LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params") LOG.info(f"skipped: {skipped / 2**20}M params")
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer(self.optimizer)
self.optimizer
)
return self.optimizer return self.optimizer

View File

@@ -46,7 +46,7 @@ class SchedulerMixin(Trainer):
) )
# fmt: off # fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition if self.lr_scheduler is None: # type: ignore
# fmt: on # fmt: on
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler( lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
@@ -90,7 +90,7 @@ class SchedulerMixin(Trainer):
LOG.warning( LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") "Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(
optimizer, optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
@@ -98,7 +98,7 @@ class SchedulerMixin(Trainer):
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(
optimizer, optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
@@ -107,7 +107,7 @@ class SchedulerMixin(Trainer):
) )
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_min_lr(
optimizer, optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
@@ -133,7 +133,7 @@ class SchedulerMixin(Trainer):
) )
if not self.lr_scheduler: if not self.lr_scheduler:
super().create_scheduler(num_training_steps, optimizer) super().create_scheduler(num_training_steps, optimizer)
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = JaggedLRRestartScheduler(
optimizer, optimizer,
self.lr_scheduler, self.lr_scheduler,
self.args.jagged_restart_steps, self.args.jagged_restart_steps,

View File

@@ -14,7 +14,6 @@ class AxolotlTrainingMixins:
Mixin class for the Axolotl training args. Mixin class for the Axolotl training args.
""" """
# pylint: disable=duplicate-code
model_type: Optional[str] = field( model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."} default=None, metadata={"help": "HF model configuration model_type."}
) )

View File

@@ -26,7 +26,7 @@ class TokenizedPromptDataset(Dataset):
keep_in_memory: Whether to keep the tokenized dataset in memory. keep_in_memory: Whether to keep the tokenized dataset in memory.
""" """
def __init__( # pylint: disable=super-init-not-called def __init__(
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset, dataset: Dataset,
@@ -99,7 +99,7 @@ class ConstantLengthDataset(IterableDataset):
seq_length: Length of token sequences to return. seq_length: Length of token sequences to return.
""" """
def __init__( # pylint: disable=super-init-not-called def __init__(
self, self,
tokenizer, tokenizer,
datasets, datasets,

View File

@@ -79,7 +79,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg) model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# Get datasets # Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps total_num_steps = dataset_meta.total_num_steps

View File

@@ -76,7 +76,7 @@ class BasePlugin:
def __init__(self): def __init__(self):
"""Initializes the BasePlugin.""" """Initializes the BasePlugin."""
def register(self, cfg: dict): # pylint: disable=unused-argument def register(self, cfg: dict):
"""Registers the plugin with the given configuration as an unparsed dict. """Registers the plugin with the given configuration as an unparsed dict.
Args: Args:
@@ -104,14 +104,13 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset. dataset_meta: The metadata for the training dataset.
""" """
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument def pre_model_load(self, cfg: DictDefault):
"""Performs actions before the model is loaded. """Performs actions before the model is loaded.
Args: Args:
cfg: The configuration for the plugin. cfg: The configuration for the plugin.
""" """
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied. """Performs actions after the model is built/loaded, but before any adapters are applied.
@@ -119,7 +118,6 @@ class BasePlugin:
cfg: The configuration for the plugin. cfg: The configuration for the plugin.
""" """
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded. """Performs actions before LoRA weights are loaded.
@@ -128,7 +126,6 @@ class BasePlugin:
model: The loaded model. model: The loaded model.
""" """
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded. """Performs actions after LoRA weights are loaded.
@@ -137,7 +134,6 @@ class BasePlugin:
model: The loaded model. model: The loaded model.
""" """
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded. """Performs actions after the model is loaded.
@@ -146,8 +142,7 @@ class BasePlugin:
model: The loaded model. model: The loaded model.
""" """
# pylint: disable=unused-argument def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Returns a custom class for the trainer. """Returns a custom class for the trainer.
Args: Args:
@@ -157,7 +152,6 @@ class BasePlugin:
The first non-`None` trainer class returned by a plugin. The first non-`None` trainer class returned by a plugin.
""" """
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created. """Performs actions after the trainer is created.
@@ -166,7 +160,7 @@ class BasePlugin:
trainer: The trainer object for training. trainer: The trainer object for training.
""" """
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument): def get_training_args(self, cfg: DictDefault):
""" """
Returns custom training arguments to set on TrainingArgs. Returns custom training arguments to set on TrainingArgs.
@@ -177,9 +171,7 @@ class BasePlugin:
object: dict containing the training arguments. object: dict containing the training arguments.
""" """
def get_collator_cls_and_kwargs( def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False):
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
""" """
Returns a custom class for the collator. Returns a custom class for the collator.
@@ -191,7 +183,6 @@ class BasePlugin:
class: The class for the collator. class: The class for the collator.
""" """
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training. """Creates and returns an optimizer for training.
@@ -203,7 +194,6 @@ class BasePlugin:
The created optimizer. The created optimizer.
""" """
# pylint: disable=unused-argument
def create_lr_scheduler( def create_lr_scheduler(
self, self,
cfg: DictDefault, cfg: DictDefault,
@@ -223,7 +213,6 @@ class BasePlugin:
The created learning rate scheduler. The created learning rate scheduler.
""" """
# pylint: disable=unused-argument
def add_callbacks_pre_trainer( def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]: ) -> list[Callable]:
@@ -238,7 +227,6 @@ class BasePlugin:
""" """
return [] return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer( def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]: ) -> list[Callable]:
@@ -254,7 +242,6 @@ class BasePlugin:
""" """
return [] return []
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete. """Performs actions after training is complete.
@@ -263,7 +250,7 @@ class BasePlugin:
model: The loaded model. model: The loaded model.
""" """
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument def post_train_unload(self, cfg: DictDefault):
"""Performs actions after training is complete and the model is unloaded. """Performs actions after training is complete and the model is unloaded.
Args: Args:
@@ -311,7 +298,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin return plugin
class PluginManager: # pylint: disable=too-many-public-methods class PluginManager:
"""The `PluginManager` class is responsible for loading and managing plugins. It """The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase. should be a singleton so it can be accessed from anywhere in the codebase.

View File

@@ -50,15 +50,9 @@ def merge_input_args():
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {} namespace: Dict[Any, Any] = {}
exec( # pylint: disable=exec-used # nosec B102 exec(dynamic_input, globals(), namespace) # nosec B102
dynamic_input, globals(), namespace AxolotlInputConfig = namespace["AxolotlInputConfig"]
) AxolotlConfigWCapabilities = namespace["AxolotlConfigWCapabilities"]
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
"AxolotlInputConfig"
]
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
@@ -74,7 +68,7 @@ def merge_training_args() -> Type:
Returns: Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins. tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
""" """
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import ( from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase, AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
) )
@@ -93,11 +87,7 @@ def merge_training_args() -> Type:
namespace: Dict[Any, Any] = {} namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase} local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102 exec(dynamic_input, {**globals(), **local_vars}, namespace) # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace AxolotlTrainingMixins = namespace["AxolotlTrainingMixins"]
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase return AxolotlTrainingMixinsBase

View File

@@ -18,6 +18,7 @@ Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team. from Apple's ML team.
""" """
import importlib import importlib
from functools import partial from functools import partial
@@ -28,7 +29,7 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 from .args import CutCrossEntropyArgs as CutCrossEntropyArgs
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -106,9 +107,7 @@ class CutCrossEntropyPlugin(BasePlugin):
""" """
from cut_cross_entropy.transformers.patch import PATCH_FNS from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic( def patch_generic(maybe_model, patch_options, model_type: str):
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
import cut_cross_entropy.transformers.llama import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward from cut_cross_entropy.transformers.llama import cce_forward
@@ -121,12 +120,10 @@ class CutCrossEntropyPlugin(BasePlugin):
) )
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options
patch_options
)
model_cls.forward = cce_forward model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
raise RuntimeError( raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. " f"Could not import ForCausalLM class for model_type: {model_type}. "

View File

@@ -15,6 +15,7 @@
""" """
Module for handling Cut Cross Entropy input arguments. Module for handling Cut Cross Entropy input arguments.
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator

View File

@@ -1,164 +0,0 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training using the LLaDA (Large Language
And Diffusion Assistant) approach within the Axolotl framework.
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to see the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. To use it, simply add the plugin configuration to
your training config.
## Quickstart
### Basic Configuration
Add the following to your Axolotl configuration YAML:
```yaml
# Enable diffusion LM training plugin
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
# Diffusion-specific configuration
noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
# Sample generation (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
# Model configuration
base_model: meta-llama/Llama-3.2-1B
model_type: llama
# Standard Axolotl configuration
datasets:
- path: your_dataset
...
# Other config
sequence_len: 1024
micro_batch_size: 8
gradient_accumulation_steps: 4
learning_rate: 3e-4
```
## Supported Models
Currently supported base model types:
- **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
The plugin automatically creates custom model classes that inherit from the base model
while adding diffusion training capabilities. This provides full compatibility with
HuggingFace's ecosystem for saving, loading, and inference.
## How It Works
### Custom Model Architecture
The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
standard HuggingFace models. During training, these models:
1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
### Random Masking
During training, tokens are randomly masked based on a sampled timestep:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
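A minimal PyTorch sketch of that masking step (illustrative only; `mask_tokens` is not the plugin's API, and the real implementation additionally skips padding tokens, special tokens, and, for SFT data, prompt tokens):
```python
import torch

def mask_tokens(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    """Mask each token with probability p = (1 - eps) * t + eps, t ~ U[0, 1) per sample."""
    bsz, seq_len = input_ids.shape
    t = torch.rand(bsz, device=input_ids.device)                     # one timestep per sequence
    p_mask = ((1 - eps) * t + eps).unsqueeze(1).expand(-1, seq_len)  # [bsz, seq_len]
    masked = torch.rand(bsz, seq_len, device=input_ids.device) < p_mask
    noisy_ids = torch.where(masked, torch.full_like(input_ids, mask_token_id), input_ids)
    return noisy_ids, masked, p_mask

# Toy usage with the default Llama 3.2 mask token id:
ids = torch.randint(0, 1000, (2, 8))
noisy, masked, p_mask = mask_tokens(ids, mask_token_id=128002)
```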
### Bidirectional Attention
The models override causal attention with bidirectional attention:
- Creates 4D attention masks allowing all-to-all attention
- Maintains proper padding and sample packing masks
- Compatible with standard HuggingFace attention implementations
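When sample packing is enabled, the attention mask carries a per-token sample ID, and the 4D mask only allows attention within the same sample; without packing it degenerates to an all-ones mask. A small sketch of that construction (hypothetical helper name):
```python
import torch

def bidirectional_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Build a [batch, 1, seq, seq] boolean mask from per-token sample IDs (0 = padding)."""
    mask_i = attention_mask.unsqueeze(2)  # [batch, seq, 1]
    mask_j = attention_mask.unsqueeze(1)  # [batch, 1, seq]
    allowed = (mask_i == mask_j) & (mask_i > 0)  # same non-padding sample, both directions
    return allowed.unsqueeze(1)           # add a broadcastable head dimension

# Two packed samples (IDs 1 and 2) followed by one padding token:
packed = torch.tensor([[1, 1, 1, 2, 2, 0]])
print(bidirectional_mask(packed)[0, 0].int())
```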
### Diffusion Loss
Loss is computed only on masked tokens with (optional) importance weighting:
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
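Spelled out with the optional importance weighting, the computation looks roughly like the sketch below (the actual loss additionally normalizes per answer length when SFT-style `labels` are present):
```python
import torch
import torch.nn.functional as F

def diffusion_loss(logits, targets, masked, p_mask, importance_weighting=True):
    """Cross-entropy on masked positions only, optionally divided by the mask probability."""
    token_loss = F.cross_entropy(logits[masked].float(), targets[masked], reduction="none")
    if importance_weighting:
        token_loss = token_loss / p_mask[masked].float()
    # Normalize over every token position in the batch, masked or not.
    return token_loss.sum() / targets.numel()
```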
### Model Loading and Saving
The custom models work seamlessly with HuggingFace's AutoModel system:
```python
from transformers import AutoModel, AutoConfig
# Load a diffusion model
model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
# Save a diffusion model
model.save_pretrained("path/to/save/diffusion/model")
```
During inference, the models behave like standard causal language models.
## Sample Generation
When `generate_samples: true`, the plugin generates samples during training:
```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to console and wandb (if enabled).
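Under the hood, generation masks part of a validation sequence and then runs a fixed number of reverse-diffusion steps, each time re-predicting the still-masked positions and keeping only the most confident ones. One such step in a simplified greedy form (illustrative; the plugin also supports Gumbel-noise sampling when `generation_temperature > 0`):
```python
import torch

def unmask_step(logits, sequence, mask_token_id, step, num_steps):
    """Fill in the most confident masked positions for a batch-size-1 sequence."""
    masked = sequence == mask_token_id
    if not masked.any():
        return sequence
    probs = torch.softmax(logits[masked], dim=-1)
    probs[:, mask_token_id] = 0.0                      # never predict the mask token itself
    confidence, predictions = probs.max(dim=-1)
    remaining = int(masked.sum())
    # Unmask a growing share each step; whatever is left gets filled on the last step.
    k = remaining if step == num_steps - 1 else max(1, int(remaining / (num_steps - step)))
    top = confidence.topk(k).indices
    positions = torch.where(masked)[1][top]            # column indices, batch size 1
    sequence[0, positions] = predictions[top]
    return sequence
```
The caller loops over `num_steps`, recomputing logits with bidirectional attention before each step.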
## Metrics and Monitoring
The plugin adds several metrics to track diffusion training:
- `train/loss`: Weighted diffusion loss
- `train/accuracy`: Accuracy on masked tokens
- `train/mask_ratio`: Average fraction of tokens masked
- `train/num_masked_tokens`: Number of tokens masked
- `train/avg_p_mask`: Average masking probability
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Benefits of Custom Model Approach
- **Type Safety**: Full IDE support and type checking
- **HuggingFace Integration**: Works with AutoModel, Hub, pipelines
- **Maintainability**: Clean architecture, no monkey patching
- **Ecosystem Compatibility**: Standard save/load, PEFT support
- **Testing**: Easier to test and debug
## Limitations
- **Model Support**: Currently limited to Llama and Mistral architectures
- **Flash Attention**: Not yet optimized for flash attention
- **Inference Speed**: Bidirectional attention is slower than causal for generation
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)

View File

@@ -1,26 +0,0 @@
"""Diffusion LM training plugin init."""
from transformers import AutoConfig, AutoModel
from .args import DiffusionArgs
from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
from .plugin import DiffusionPlugin
# Register custom configurations
AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
# Register custom models
AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionConfig",
"LlamaForDiffusionConfig",
"MistralForDiffusionConfig",
"LlamaForDiffusionLM",
"MistralForDiffusionLM",
]

View File

@@ -1,70 +0,0 @@
"""Config args for diffusion LM training."""
from typing import Literal
from pydantic import BaseModel, Field
class DiffusionArgs(BaseModel):
"""Arguments for diffusion LM training plugin."""
# Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training config
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)
mask_token_id: int = Field(
default=128002,
description=(
"Token ID to use for masking. Default is 128002 "
"(<|reserved_special_token_0|> for Llama 3.2)"
),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
)

View File

@@ -1,116 +0,0 @@
"""Callbacks for diffusion training."""
import wandb
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.utils.logging import get_logger
from .generation import generate_samples
LOG = get_logger(__name__)
class DiffusionGenerationCallback(TrainerCallback):
"""Callback for generating samples during diffusion training."""
def __init__(self, trainer):
self.trainer = trainer
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if (
state.global_step > 0
and state.global_step % config.get('generation_interval', 100) == 0
):
# Use eval dataloader if available, otherwise use train dataloader
if (
hasattr(self.trainer, "eval_dataset")
and self.trainer.eval_dataset is not None
):
dataloader = self.trainer.get_eval_dataloader()
else:
dataloader = self.trainer.get_train_dataloader()
# Generate samples
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.tokenizer,
dataloader=dataloader,
num_generation_samples=config.get('num_generation_samples', 3),
max_length=config.get('generation_max_length', 256),
num_diffusion_steps=config.get('generation_steps', 10),
temperature=config.get('generation_temperature', 1.0),
mask_token_id=config.get('mask_token_id', 32000),
)
# Log samples
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples."""
if not samples:
return
LOG.info("=" * 60)
LOG.info("GENERATED SAMPLES")
LOG.info("=" * 60)
for i, sample_data in enumerate(samples, 1):
original = sample_data["original"]
masked = sample_data["masked"]
generated = sample_data["generated"]
mask_ratio = sample_data["mask_ratio"]
masked_tokens = sample_data["masked_tokens"]
total_tokens = sample_data["total_tokens"]
LOG.info(f"\nSample {i}:")
LOG.info(f"\tOriginal ({total_tokens} tokens): {original}")
LOG.info(
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
f"{mask_ratio:.1%}): {masked}"
)
LOG.info(f"\tGenerated: {generated}")
LOG.info("=" * 60)
config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
if wandb.run is not None:
wandb.log(
{
"generated_samples": wandb.Table(
columns=[
"step",
"original",
"masked",
"generated",
"mask_ratio",
"masked_tokens",
"total_tokens",
],
data=[
[
step,
sample["original"],
sample["masked"],
sample["generated"],
f"{sample['mask_ratio']:.1%}",
sample["masked_tokens"],
sample["total_tokens"],
]
for sample in samples
],
)
},
step=step,
)

View File

@@ -1,71 +0,0 @@
"""Configuration classes for diffusion language models."""
from transformers import LlamaConfig, MistralConfig
class LlamaForDiffusionConfig(LlamaConfig):
"""Configuration class for Llama models with diffusion training."""
model_type = "llama_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
class MistralForDiffusionConfig(MistralConfig):
"""Configuration class for Mistral models with diffusion training."""
model_type = "mistral_diffusion"
def __init__(
self,
mask_token_id: int = 32000,
eps: float = 1e-3,
importance_weighting: bool = False,
sample_packing: bool = False,
min_mask_ratio: float = 0.0,
max_mask_ratio: float = 1.0,
noise_schedule: str = "linear",
**kwargs,
):
super().__init__(**kwargs)
# Diffusion-specific parameters
self.mask_token_id = mask_token_id
self.eps = eps
self.importance_weighting = importance_weighting
self.sample_packing = sample_packing
self.min_mask_ratio = min_mask_ratio
self.max_mask_ratio = max_mask_ratio
self.noise_schedule = noise_schedule
# Keep the base class for backward compatibility but mark as deprecated
class DiffusionConfig(LlamaForDiffusionConfig):
"""
Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
"""
model_type = "diffusion"
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,269 +0,0 @@
"""Sample generation utilities for diffusion training."""
import logging
from typing import Any, List, Optional
import torch
logger = logging.getLogger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
temperature: float = 0.0,
mask_token_id: int = 32000,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
temperature: Temperature for sampling (0.0 = deterministic)
mask_token_id: Token ID used for masking
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if dataloader is None:
logger.warning("No validation dataloader provided, cannot generate samples")
return []
# Get the actual model (unwrap if needed)
unwrapped_model = model.module if hasattr(model, "module") else model
unwrapped_model.eval()
generations = []
# Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader(
dataloader, num_generation_samples, max_length, unwrapped_model.device
)
logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
# Generate samples using reverse diffusion process
with torch.no_grad():
for original_sequence in sampled_sequences:
generation_result = _generate(
unwrapped_model,
tokenizer,
original_sequence,
num_diffusion_steps,
temperature,
mask_token_id,
)
generations.append(generation_result)
unwrapped_model.train()
return generations
def _sample_sequences_from_dataloader(
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[torch.Tensor]:
"""Sample sequences from validation dataloader."""
sampled_sequences = []
sample_count = 0
# Add randomness by skipping a random number of batches
skip_batches = torch.randint(0, 6, (1,)).item()
batch_count = 0
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
# Limit sequence length to max_length
actual_length = min(seq_len, max_length)
if actual_length < 10: # Skip very short sequences
continue
# Extract the sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def _generate(
model: torch.nn.Module,
tokenizer: Any,
original_sequence: torch.Tensor,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> dict:
"""Generate a single sample using reverse diffusion."""
# Get original text for comparison
original_text = tokenizer.decode(
original_sequence[0].cpu(), skip_special_tokens=True
)
# Apply custom masking with random ratio (10% to 70%)
total_tokens = original_sequence.size(1)
min_ratio, max_ratio = 0.1, 0.7
target_mask_ratio = torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
target_masked_tokens = int(total_tokens * target_mask_ratio)
# Create random mask indices
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, mask_positions] = True
# Create masked sequence
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
# Calculate actual mask ratio
masked_tokens = masked_indices.sum().item()
mask_ratio = masked_tokens / total_tokens
# Get masked text for comparison
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
# Clean up mask token representation
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
# Run reverse diffusion process
sequence = masked_sequence.clone()
for step in range(num_diffusion_steps):
sequence = _diffusion_step(
model, sequence, step, num_diffusion_steps, temperature, mask_token_id
)
# Get final generated text
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
return {
"original": original_text,
"masked": masked_text,
"generated": generated_text,
"mask_ratio": mask_ratio,
"masked_tokens": masked_tokens,
"total_tokens": total_tokens,
"formatted": (
f"Original: '{original_text}' → Masked: '{masked_text}' "
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
),
}
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
if hasattr(tokenizer, "special_tokens_map"):
for token_value in tokenizer.special_tokens_map.values():
if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
cleaned = " ".join(cleaned.split()).strip()
return cleaned
def _diffusion_step(
model: torch.nn.Module,
sequence: torch.Tensor,
step: int,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
) -> torch.Tensor:
"""Perform a single diffusion step with remasking."""
# Only process if there are masked tokens remaining
current_mask = sequence == mask_token_id
if not current_mask.any():
return sequence
# Create bidirectional attention mask for diffusion
batch_size, seq_len = sequence.shape
attention_mask = torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
)
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():
masked_logits = logits[current_mask]
# Apply temperature scaling
if temperature > 0:
scaled_logits = masked_logits / temperature
else:
scaled_logits = masked_logits
# Suppress mask token in outputs
scaled_logits[:, mask_token_id] = -float("inf")
# Sample predictions
if temperature > 0:
# Add Gumbel noise for sampling
gumbel_noise = -torch.log(
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
)
gumbel_logits = scaled_logits + gumbel_noise
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
else:
# Deterministic sampling when temperature is 0
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
# Calculate probabilities for confidence scoring
probs = torch.softmax(scaled_logits, dim=-1)
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
# Determine how many tokens to unmask this step
remaining_masked = current_mask.sum().item()
if step == num_diffusion_steps - 1:
num_to_unmask = remaining_masked
else:
unmask_ratio = 1.0 / (num_diffusion_steps - step)
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
# Select highest confidence predictions to unmask
if num_to_unmask >= remaining_masked:
sequence[current_mask] = predicted_tokens
else:
_, top_indices = predicted_token_probs.topk(num_to_unmask)
mask_positions = torch.where(current_mask)[1]
positions_to_unmask = mask_positions[top_indices]
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
return sequence

View File

@@ -1,426 +0,0 @@
"""Custom model classes for diffusion language models."""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
class DiffusionModelMixin:
"""Mixin class providing diffusion functionality to language models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._special_token_ids = None
def _cache_special_token_ids(self, tokenizer=None):
"""Cache special token IDs to avoid repeated tokenizer access."""
if tokenizer is None:
self._special_token_ids = set()
return
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
# Calculate masking probability with epsilon
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = self.config.mask_token_id
noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)
return noisy_batch, masked_indices, p_mask
def _create_bidirectional_attention_mask(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking. Handles
sample-packed sequences where different samples are identified by different
attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len]
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not self.config.sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Create attention mask by comparing sample IDs element-wise
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
bidirectional_mask = bidirectional_mask.unsqueeze(1)
return bidirectional_mask
def _compute_diffusion_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
masked_indices: torch.Tensor | None = None,
p_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute diffusion loss given logits and masking information.
Args:
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
logits: Model logits [batch_size, seq_len, vocab_size].
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
"""
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.config.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
# Final loss: sum weighted losses, normalize
if labels is not None:
# For SFT data: normalize by answer length per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
loss_per_sample = torch.zeros(
input_ids.shape[0], device=input_ids.device
)
for i in range(input_ids.shape[0]):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
loss_per_sample[i] = sample_loss / answer_lengths[i]
loss = loss_per_sample.mean()
else:
# Original normalization for non-SFT data
loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return loss
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
"""
Llama model for diffusion language modeling.
This model extends LlamaForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = LlamaForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
"""
Mistral model for diffusion language modeling.
This model extends MistralForCausalLM with diffusion training capabilities,
including bidirectional attention and forward diffusion process.
"""
config_class = MistralForDiffusionConfig
def __init__(self, config):
super().__init__(config)
# Initialize diffusion-specific attributes
self._special_token_ids = None
# Initialize weights and apply final processing
self.post_init()
def set_tokenizer(self, tokenizer):
"""Set tokenizer for special token handling."""
self._cache_special_token_ids(tokenizer)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass with diffusion training logic.
During training, applies forward diffusion process and bidirectional attention.
During inference, behaves like standard causal language model.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and input_ids is not None:
# Apply diffusion process during training
original_input_ids = input_ids.clone()
# Apply forward process to get noisy input
noisy_input_ids, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.config.eps
)
# Create bidirectional attention mask
bidirectional_attention_mask = self._create_bidirectional_attention_mask(
input_ids, attention_mask
)
# Forward pass with noisy input and bidirectional attention
outputs = super().forward(
input_ids=noisy_input_ids,
attention_mask=bidirectional_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=None, # Don't use standard loss computation
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Compute diffusion loss
loss = self._compute_diffusion_loss(
original_input_ids,
attention_mask,
labels,
outputs.logits,
masked_indices,
p_mask,
)
if return_dict:
outputs.loss = loss
return outputs
else:
return (loss,) + outputs[1:]
else:
# Standard forward pass for inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
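The training branch above delegates the noising step to `_forward_process`, which lives on DiffusionModelMixin and is not shown in this diff. For orientation only, a LLaDA-style forward process samples a per-sequence mask ratio and replaces tokens with the mask token id; the sketch below illustrates that idea under those assumptions (the helper name, signature, and weighting are illustrative, not the mixin's actual code).

import torch

def llada_style_forward_process(input_ids: torch.LongTensor, mask_token_id: int, eps: float = 1e-3):
    """Illustrative LLaDA-style noising: mask each token with a per-sequence probability t ~ U(eps, 1)."""
    batch_size, seq_len = input_ids.shape
    # One masking ratio per sequence, bounded away from zero so every sample is partially masked.
    t = torch.rand(batch_size, device=input_ids.device) * (1.0 - eps) + eps
    p_mask = t.unsqueeze(1).expand(batch_size, seq_len)
    masked_indices = torch.rand_like(p_mask) < p_mask
    noisy_input_ids = torch.where(
        masked_indices, torch.full_like(input_ids, mask_token_id), input_ids
    )
    # p_mask is returned so the loss can be importance-weighted by the probability a token was masked.
    return noisy_input_ids, masked_indices, p_mask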

View File

@@ -1,98 +0,0 @@
"""Diffusion LM training plugin for Axolotl."""
from typing import TYPE_CHECKING
from peft import PeftModel
from transformers import AutoConfig, AutoModel, PreTrainedModel
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM
if TYPE_CHECKING:
from transformers import Trainer
LOG = get_logger(__name__)
class DiffusionPlugin(BasePlugin):
"""
Plugin for diffusion language model training.
This plugin enables diffusion-based training using the LLaDA approach, which uses
random masking and bidirectional attention to train language models.
"""
def __init__(self):
super().__init__()
self.cfg = None
def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def pre_model_load(self, cfg: DictDefault):
"""Configure model loading to use diffusion model classes."""
# Map base model types to diffusion equivalents
base_model_type = cfg.get("model_type")
if base_model_type == "llama":
# Create diffusion config from base config
diffusion_config = LlamaForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "llama_diffusion"
elif base_model_type == "mistral":
# Create diffusion config from base config
diffusion_config = MistralForDiffusionConfig(
mask_token_id=getattr(cfg, "mask_token_id", 32000),
eps=getattr(cfg, "eps", 1e-3),
importance_weighting=getattr(cfg, "importance_weighting", False),
sample_packing=getattr(cfg, "sample_packing", False),
min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
noise_schedule=getattr(cfg, "noise_schedule", "linear"),
)
# Override model type for loading
cfg.model_type = "mistral_diffusion"
else:
LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Configure model after loading."""
self.cfg = cfg
# Set tokenizer on diffusion models for special token handling
if hasattr(model, "set_tokenizer"):
# Get tokenizer from cfg if available
tokenizer = getattr(cfg, "tokenizer", None)
if tokenizer is not None:
model.set_tokenizer(tokenizer)
def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"):
"""Add diffusion-specific callbacks after trainer creation."""
callbacks = []
# Store diffusion config on trainer for callbacks
trainer.diffusion_config = cfg
# Add generation callback if enabled
if cfg.get("generate_samples", False):
generation_callback = DiffusionGenerationCallback(trainer)
callbacks.append(generation_callback)
return callbacks
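Functionally, the pre_model_load hook shown above boils down to rewriting cfg.model_type before the loader resolves a model class. A minimal sketch of exercising that hook follows; the cfg values are illustrative and the DiffusionPlugin import path is assumed from the package layout, not confirmed by this diff.

from axolotl.integrations.diffusion import DiffusionPlugin  # assumed import path
from axolotl.utils.dict import DictDefault

plugin = DiffusionPlugin()
cfg = DictDefault({"model_type": "mistral", "mask_token_id": 32000, "eps": 1e-3})

plugin.pre_model_load(cfg)
# The loader will now look up the diffusion variant instead of the stock Mistral class.
assert cfg.model_type == "mistral_diffusion"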

View File

@@ -7,7 +7,7 @@ from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from ..base import BasePlugin from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401 from .args import GrokfastArgs as GrokfastArgs
from .optimizer import gradfilter_ema from .optimizer import gradfilter_ema
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -24,12 +24,10 @@ class GrokfastCallbackHandler(TrainerCallback):
self.alpha = alpha self.alpha = alpha
self.lamb = lamb self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument def on_train_begin(self, *args_, **kwargs):
self.grads = None self.grads = None
def on_pre_optimizer_step( def on_pre_optimizer_step(self, args_, state, control, **kwargs):
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
model = kwargs.pop("model") model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb) self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control return control

View File

@@ -1,7 +1,6 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee # Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast # Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque from collections import deque
from typing import Dict, Literal, Optional from typing import Dict, Literal, Optional

View File

@@ -15,6 +15,7 @@
""" """
Plugin init to add KD support to Axolotl. Plugin init to add KD support to Axolotl.
""" """
from typing import Any from typing import Any
from transformers import Trainer from transformers import Trainer
@@ -22,7 +23,7 @@ from transformers import Trainer
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401 from .args import KDArgs as KDArgs
class KDPlugin(BasePlugin): class KDPlugin(BasePlugin):

View File

@@ -15,6 +15,7 @@
""" """
Plugin args for KD support. Plugin args for KD support.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@@ -26,8 +27,8 @@ class InferenceServerType(str, Enum):
Online inferences server types to handle different request args Online inferences server types to handle different request args
""" """
vllm = "vllm" # pylint: disable=invalid-name vllm = "vllm"
sglang = "sglang" # pylint: disable=invalid-name sglang = "sglang"
class KDArgs(BaseModel): class KDArgs(BaseModel):

View File

@@ -19,9 +19,7 @@ class KDTemperatureSchedulerCallback(TrainerCallback):
self.trainer = trainer self.trainer = trainer
def on_step_end( def on_step_end(self, args, state, control, **kwargs):
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# cosine decay temperature over the max steps # cosine decay temperature over the max steps
progress = state.global_step / state.max_steps progress = state.global_step / state.max_steps

View File

@@ -15,6 +15,7 @@
""" """
Chat template prompt strategy loader with KD support Chat template prompt strategy loader with KD support
""" """
import logging import logging
from typing import Any, Dict from typing import Any, Dict
@@ -192,7 +193,6 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
""" """
Transform logprobs to target format for KD training Transform logprobs to target format for KD training
""" """
# pylint: disable=duplicate-code
logprobs = sample.pop(self.logprobs_field) logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
@@ -240,7 +240,7 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
target_mask.append([1] * top_k) target_mask.append([1] * top_k)
for token_pos_logprobs, pos_target_token_ids in zip( for token_pos_logprobs, pos_target_token_ids in zip(
logprobs, sample["target_token_ids"] logprobs, sample["target_token_ids"], strict=False
): ):
# Convert to a tensor for easier manipulation # Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor( position_logprobs_tensor = torch.tensor(
@@ -299,7 +299,7 @@ class KDStrategyLoader(StrategyLoader):
Load ChatTemplateStrategy with KD support using StrategyLoader. Load ChatTemplateStrategy with KD support using StrategyLoader.
""" """
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument def _get_strategy_cls(self, cfg):
return ChatTemplateStrategyWithKD return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
@@ -319,7 +319,7 @@ class KDStrategyLoaderV2(KDStrategyLoader):
Load KD chat template datasets with pre-tokenized logprob data Load KD chat template datasets with pre-tokenized logprob data
""" """
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument def _get_strategy_cls(self, cfg):
return ChatTemplateStrategyWithKDv2 return ChatTemplateStrategyWithKDv2

View File

@@ -37,7 +37,6 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_logprobs. It also creates a teacher_mask to indicate which entries are valid. target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
""" """
# pylint: disable=duplicate-code
tokenizer: PreTrainedTokenizerBase tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True padding: Union[bool, str, PaddingStrategy] = True
@@ -72,7 +71,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
// self.pad_to_multiple_of // self.pad_to_multiple_of
) * self.pad_to_multiple_of ) * self.pad_to_multiple_of
for f in features: # pylint: disable=invalid-name for f in features:
remainder = [pad_token_id] * (max_len - len(f[feature_name])) remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list): if isinstance(f[feature_name], list):
f[feature_name] = ( f[feature_name] = (
@@ -101,7 +100,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
if has_teacher_data: if has_teacher_data:
# Extract and remove from features # Extract and remove from features
for f in features: # pylint: disable=invalid-name for f in features:
target_logprobs_list.append(f.pop("target_logprobs")) target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids")) target_token_ids_list.append(f.pop("target_token_ids"))
target_mask_list.append(f.pop("target_mask")) target_mask_list.append(f.pop("target_mask"))
@@ -117,24 +116,25 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
padded_teacher_mask_list = [] padded_teacher_mask_list = []
for t_logprobs, t_ids, t_mask in zip( for t_logprobs, t_ids, t_mask in zip(
target_logprobs_list, target_token_ids_list, target_mask_list target_logprobs_list,
target_token_ids_list,
target_mask_list,
strict=False,
): ):
t_logprobs_padded = [] t_logprobs_padded = []
t_ids_padded = [] t_ids_padded = []
t_mask_padded = [] t_mask_padded = []
for lp, ids, mask in zip( # pylint: disable=invalid-name for lp, ids, mask in zip(t_logprobs, t_ids, t_mask, strict=False):
t_logprobs, t_ids, t_mask
):
lp_len = len(lp) lp_len = len(lp)
if lp_len < max_k: if lp_len < max_k:
# Use -1e9 for padding logprobs and 0 for token_ids # Use -1e9 for padding logprobs and 0 for token_ids
pad_len = max_k - lp_len pad_len = max_k - lp_len
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name lp = lp + [-1e9] * pad_len
ids = ids + [0] * pad_len ids = ids + [0] * pad_len
mask = mask + [0] * pad_len mask = mask + [0] * pad_len
else: else:
lp = lp[:max_k] # pylint: disable=invalid-name lp = lp[:max_k]
ids = ids[:max_k] ids = ids[:max_k]
mask = mask[:max_k] mask = mask[:max_k]
@@ -216,9 +216,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# We want to produce a single "merged" feature dict for each sub-batch. # We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features] out_features = [{} for _ in features]
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks for i, sub_features in enumerate(features):
features
):
# sub_features is a list of dicts, each dict = one sequences features # sub_features is a list of dicts, each dict = one sequences features
# We'll merge them into out_features[i]. # We'll merge them into out_features[i].
# #
@@ -255,9 +253,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
if field_name in feat and isinstance( if field_name in feat and isinstance(
feat[field_name], (list, torch.Tensor) feat[field_name], (list, torch.Tensor)
): ):
if isinstance( if isinstance(feat[field_name][0], (dict, str)):
feat[field_name][0], (dict, str)
): # pylint: disable=too-many-nested-blocks
continue continue
arr = np.array(feat[field_name]) arr = np.array(feat[field_name])
arrays.append(arr) arrays.append(arr)

View File

@@ -144,7 +144,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
} }
for sequence_data, seq_input_ids, seq_labels in zip( for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels api_data, batch_input_ids, labels, strict=False
): ):
current_target_logprobs = [] current_target_logprobs = []
current_target_token_ids = [] current_target_token_ids = []
@@ -165,7 +165,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
assert len(seq_input_ids) == len(input_top_logprobs) assert len(seq_input_ids) == len(input_top_logprobs)
for i, _, label in zip( for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels range(len(seq_input_ids)), seq_input_ids, seq_labels, strict=False
): ):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None: if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token. # this is always the case for the first token.
@@ -202,7 +202,8 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [ pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data) list(row)
for row in zip(*pos_top_logprobs_data, strict=False)
] ]
# Ensure correct length (top_k) # Ensure correct length (top_k)
@@ -317,7 +318,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
} }
for sequence_data, seq_input_ids, seq_labels in zip( for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels choices, batch_input_ids, labels, strict=False
): ):
# seq_input_ids: List[int] # seq_input_ids: List[int]
# seq_labels: List[int] # seq_labels: List[int]
@@ -342,7 +343,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
seq_len = len(seq_input_ids) seq_len = len(seq_input_ids)
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels): for i, _, label in zip(
range(seq_len), seq_input_ids, seq_labels, strict=False
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None: if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token. # this is always the case for the first token.
# there is never logprob data for the first token since that's a true input # there is never logprob data for the first token since that's a true input
@@ -424,7 +427,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
list(range(self.kd_online_topk)) list(range(self.kd_online_topk))
) )
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
for i in range(max(0, seq_len - len(current_target_logprobs))): for _ in range(max(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append( current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk [-float("inf")] * self.kd_online_topk
) )

View File

@@ -197,7 +197,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss: bool = True, compute_ce_loss: bool = True,
normalize_topk: bool = True, normalize_topk: bool = True,
): ):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name CHUNK_SIZE = chunk_size
grad_weight_acc = torch.zeros_like(student_lm_head_weight) grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = [] grad_inputs_list = []
grad_bias_acc = ( grad_bias_acc = (
@@ -298,8 +298,8 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
accumulate_chunk_grads_compiled = accumulate_chunk_grads accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward # Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape # pylint: disable=invalid-name B, N, D = student_input.shape
K = target_token_ids.shape[-1] # pylint: disable=invalid-name K = target_token_ids.shape[-1]
student_input_flat = student_input.reshape(-1, student_input.shape[-1]) student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1]) target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])

View File

@@ -40,10 +40,9 @@ def kldiv_forward_llama_like(
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs], # type: ignore[misc] **kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
) -> CausalLMOutputWithPast: ) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = ( output_attentions = (
output_attentions output_attentions
if output_attentions is not None if output_attentions is not None

View File

@@ -15,6 +15,7 @@
""" """
loss for top_k KL divergence loss for top_k KL divergence
""" """
import torch import torch
from torch import nn from torch import nn
@@ -117,7 +118,6 @@ class ChunkedTopKKDLoss(nn.Module):
target_mask: torch.Tensor, # [B, seq_len, K] target_mask: torch.Tensor, # [B, seq_len, K]
num_items_in_batch: int = -1, # optional batch size for normalization num_items_in_batch: int = -1, # optional batch size for normalization
) -> torch.Tensor: ) -> torch.Tensor:
# 1. Split along the "token" dimension (dim=1). # 1. Split along the "token" dimension (dim=1).
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1) student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1) token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
@@ -131,7 +131,11 @@ class ChunkedTopKKDLoss(nn.Module):
# 2. Loop over each chunk and compute a chunk-specific loss. # 2. Loop over each chunk and compute a chunk-specific loss.
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip( for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks student_logits_chunks,
token_ids_chunks,
logprobs_chunks,
mask_chunks,
strict=False,
): ):
# We pass num_items_in_batch=-1 so that the kd_loss # We pass num_items_in_batch=-1 so that the kd_loss
# will average over *this chunk's* valid tokens only. # will average over *this chunk's* valid tokens only.
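The chunk loop defers the per-chunk math to kd_loss, which is outside this hunk. As a hedged sketch only: assuming the term is a forward KL restricted to the teacher's top-k support, with target_mask zeroing padded entries and the normalization running over this chunk's valid entries, one chunk could be computed roughly as below (shapes follow the [B, seq_len, K] convention from the signature; names and normalization are illustrative, not the trainer's actual loss).

import torch
import torch.nn.functional as F

def topk_kd_chunk_loss(student_logits, target_token_ids, target_logprobs, target_mask):
    # student_logits: [B, T, vocab]; the teacher tensors are [B, T, K] (top-k support).
    student_logprobs = F.log_softmax(student_logits.float(), dim=-1)
    student_topk = student_logprobs.gather(-1, target_token_ids)  # [B, T, K]
    teacher_logprobs = target_logprobs.float()
    teacher_probs = teacher_logprobs.exp()
    # Forward KL on the top-k support, masked so padded slots contribute nothing.
    per_entry = teacher_probs * (teacher_logprobs - student_topk) * target_mask.float()
    return per_entry.sum() / target_mask.float().sum().clamp_min(1.0)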

View File

@@ -21,7 +21,6 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
# pylint: disable=too-many-ancestors
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
""" """
Custom trainer subclass for Knowledge Distillation (KD) Custom trainer subclass for Knowledge Distillation (KD)

View File

@@ -18,6 +18,7 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training. Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight. It is designed to be performant, correct, and light-weight.
""" """
from .args import LigerArgs from .args import LigerArgs
from .plugin import LigerPlugin from .plugin import LigerPlugin

View File

@@ -41,7 +41,6 @@ def lce_forward(
This is useful when using packed tensor format (single dimension for batch and sequence length). This is useful when using packed tensor format (single dimension for batch and sequence length).
""" """
# pylint: disable=duplicate-code
output_attentions = ( output_attentions = (
output_attentions output_attentions
if output_attentions is not None if output_attentions is not None
@@ -181,7 +180,7 @@ def patch_lce_forward(
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = lce_forward model_cls.forward = lce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
raise RuntimeError( raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. " f"Could not import ForCausalLM class for model_type: {model_type}. "

View File

@@ -2,8 +2,6 @@
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch

View File

@@ -2,8 +2,6 @@
Jamba model with LigerFusedLinearCrossEntropyLoss Jamba model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch

View File

@@ -46,7 +46,6 @@ def lce_forward(
Returns: Returns:
""" """
# pylint: disable=duplicate-code
output_attentions = ( output_attentions = (
output_attentions output_attentions
if output_attentions is not None if output_attentions is not None
@@ -78,9 +77,7 @@ def lce_forward(
hidden_states = outputs[0] hidden_states = outputs[0]
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
raise Exception( # pylint: disable=broad-exception-raised raise Exception("Liger Kernel does not support pretraining_tp!!")
"Liger Kernel does not support pretraining_tp!!"
)
logits = None logits = None
loss = None loss = None
@@ -128,7 +125,7 @@ def apply_liger_kernel_to_llama4(
rms_norm: bool = False, rms_norm: bool = False,
glu_activation: bool = False, glu_activation: bool = False,
layer_norm: bool = False, layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs,
) -> None: ) -> None:
""" """
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -144,15 +141,15 @@ def apply_liger_kernel_to_llama4(
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
""" """
import transformers.models.llama4.modeling_llama4 # noqa: F401 # pylint: disable=unused-import import transformers.models.llama4.modeling_llama4 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not ( assert not (cross_entropy and fused_linear_cross_entropy), (
cross_entropy and fused_linear_cross_entropy "cross_entropy and fused_linear_cross_entropy cannot both be True."
), "cross_entropy and fused_linear_cross_entropy cannot both be True." )
modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"] modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]
@@ -165,7 +162,7 @@ def apply_liger_kernel_to_llama4(
# clone config to avoid modifying the original # clone config to avoid modifying the original
config = deepcopy(config) config = deepcopy(config)
if intermediate_size: if intermediate_size:
setattr(config, "intermediate_size", intermediate_size) config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs) return LigerSwiGLUMLP(config, **kwargs)
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper

View File

@@ -43,7 +43,6 @@ def lce_forward(
Returns: Returns:
""" """
# pylint: disable=duplicate-code
output_attentions = ( output_attentions = (
output_attentions output_attentions
if output_attentions is not None if output_attentions is not None
@@ -113,9 +112,8 @@ def apply_liger_kernel_to_qwen3(
rms_norm: bool = False, rms_norm: bool = False,
glu_activation: bool = False, glu_activation: bool = False,
layer_norm: bool = False, layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs,
) -> None: ) -> None:
# pylint: disable=duplicate-code
""" """
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -130,15 +128,15 @@ def apply_liger_kernel_to_qwen3(
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
""" """
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import import transformers.models.qwen3.modeling_qwen3 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not ( assert not (cross_entropy and fused_linear_cross_entropy), (
cross_entropy and fused_linear_cross_entropy "cross_entropy and fused_linear_cross_entropy cannot both be True."
), "cross_entropy and fused_linear_cross_entropy cannot both be True." )
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"] modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]

View File

@@ -45,7 +45,6 @@ def lce_forward(
Returns: Returns:
""" """
# pylint: disable=duplicate-code
output_attentions = ( output_attentions = (
output_attentions output_attentions
if output_attentions is not None if output_attentions is not None
@@ -135,9 +134,8 @@ def apply_liger_kernel_to_qwen3_moe(
rms_norm: bool = False, rms_norm: bool = False,
glu_activation: bool = False, glu_activation: bool = False,
layer_norm: bool = False, layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs,
) -> None: ) -> None:
# pylint: disable=duplicate-code
""" """
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -152,15 +150,15 @@ def apply_liger_kernel_to_qwen3_moe(
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
""" """
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not ( assert not (cross_entropy and fused_linear_cross_entropy), (
cross_entropy and fused_linear_cross_entropy "cross_entropy and fused_linear_cross_entropy cannot both be True."
), "cross_entropy and fused_linear_cross_entropy cannot both be True." )
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"] modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
@@ -174,7 +172,7 @@ def apply_liger_kernel_to_qwen3_moe(
# clone config to avoid modifying the original # clone config to avoid modifying the original
config = deepcopy(config) config = deepcopy(config)
if intermediate_size: if intermediate_size:
setattr(config, "intermediate_size", intermediate_size) config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs) return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper

View File

@@ -7,7 +7,7 @@ import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401 from .args import LMEvalArgs as LMEvalArgs
class LMEvalPlugin(BasePlugin): class LMEvalPlugin(BasePlugin):
@@ -20,7 +20,6 @@ class LMEvalPlugin(BasePlugin):
def post_train_unload(self, cfg): def post_train_unload(self, cfg):
if cfg.lm_eval_post_train: if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command( for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks, cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16, bfloat16=cfg.bfloat16 or cfg.bf16,

View File

@@ -99,7 +99,6 @@ def lm_eval(config: str, cloud: Optional[str] = None):
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command( for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks, cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16, bfloat16=cfg.bfloat16 or cfg.bf16,

View File

@@ -23,7 +23,7 @@ import requests
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401 from .args import SpectrumArgs as SpectrumArgs
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -46,7 +46,7 @@ def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
"^lm_head.weight$", "^lm_head.weight$",
"^model.embed_tokens.weight$", "^model.embed_tokens.weight$",
] ]
for layer_type, layer_names in top_layers_by_type.items(): for _, layer_names in top_layers_by_type.items():
for layer_name in layer_names: for layer_name in layer_names:
unfrozen_parameters.append(layer_name) unfrozen_parameters.append(layer_name)
return unfrozen_parameters return unfrozen_parameters
@@ -84,7 +84,7 @@ class SpectrumPlugin(BasePlugin):
snr_data = json.load(fin) snr_data = json.load(fin)
except FileNotFoundError: except FileNotFoundError:
pass pass
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc:
LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}") LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data: if not snr_data:

View File

@@ -15,6 +15,7 @@
""" """
Module for handling Spectrum input arguments. Module for handling Spectrum input arguments.
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator

View File

@@ -5,8 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl

View File

@@ -7,8 +7,6 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
# pylint: disable=invalid-name
from typing import Callable from typing import Callable
import torch import torch

View File

@@ -1,7 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration.""" """Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes import ctypes
import bitsandbytes as bnb import bitsandbytes as bnb

View File

@@ -99,7 +99,6 @@ def _swiglu_bwd_kernel(
tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up
# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
""" """
SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
@@ -128,7 +127,6 @@ def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
return out return out
# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward( def swiglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
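As the swiglu_forward docstring states, the activation is x * sigmoid(x) * up, i.e. silu(gate) * up. A plain-PyTorch reference is handy for numerically checking the Triton kernels; the sketch below is only that reference, not the kernel's API.

import torch
import torch.nn.functional as F

def swiglu_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU: gate * sigmoid(gate) * up == silu(gate) * up
    return F.silu(gate) * up

gate = torch.randn(4, 8)
up = torch.randn(4, 8)
expected = swiglu_reference(gate, up)  # compare against the Triton swiglu_forward output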

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.loaders module""" """Init for axolotl.loaders module"""
# pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .adapter import load_adapter, load_lora from .adapter import load_adapter, load_lora

View File

@@ -28,14 +28,12 @@ LOG = get_logger(__name__)
def setup_quantized_meta_for_peft(model: torch.nn.Module): def setup_quantized_meta_for_peft(model: torch.nn.Module):
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument def temp_to_method(self, *args, **kwargs):
return self return self
for param in model.parameters(): for param in model.parameters():
if isinstance(param, Params4bit): if isinstance(param, Params4bit):
param.quant_state._orig_to = ( # pylint: disable=protected-access param.quant_state._orig_to = param.quant_state.to
param.quant_state.to
)
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
@@ -43,10 +41,8 @@ def setup_quantized_peft_meta_for_training(model: torch.nn.Module):
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue""" """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
for param in model.parameters(): for param in model.parameters():
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
param.quant_state.to = ( param.quant_state.to = param.quant_state._orig_to
param.quant_state._orig_to # pylint: disable=protected-access param.quant_state._orig_to = None
)
param.quant_state._orig_to = None # pylint: disable=protected-access
def find_all_linear_names(model): def find_all_linear_names(model):

Some files were not shown because too many files have changed in this diff.