#!/usr/bin/env python3
"""Build a disposable Hugging Face Kernel Hub package for ScatterMoE LoRA.

This script does not move or edit the in-tree Axolotl kernel sources. It copies
``src/axolotl/integrations/kernels/libs/scattermoe_lora`` into an ignored
build directory and emits a universal HF kernels project that can be pushed to
the Hub.
"""

from __future__ import annotations

import argparse
import fnmatch
import hashlib
import json
import os
import shutil
import subprocess
import sys
from importlib import metadata
from pathlib import Path

PACKAGE_NAME = "scattermoe_lora"
BUILD_VARIANT = "torch-universal"
DEFAULT_REPO_ID = "kernels-community/scattermoe-lora"
HF_REPO_TYPE = "kernel"

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_SOURCE_DIR = (
    REPO_ROOT / "src" / "axolotl" / "integrations" / "kernels" / "libs" / PACKAGE_NAME
)
DEFAULT_OUTPUT_DIR = REPO_ROOT / "build" / "hf-kernels" / PACKAGE_NAME

# Directory names never copied into the exported package (tool caches only).
EXCLUDED_DIRS = {
    "__pycache__",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
}
# File name patterns never copied into the exported package (build artifacts).
EXCLUDED_FILE_PATTERNS = {
    "*.pyc",
    "*.pyo",
    "*.so",
    ".DS_Store",
}

# Absolute in-tree imports rewritten to package-relative imports so the copied
# sources work when loaded as a standalone HF kernel package.
TEXT_REPLACEMENTS = {
    "from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import": (
        "from .selective_dequant import"
    ),
    "from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import": (
        "from .selective_dequant_kernel import"
    ),
    "from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import": (
        "from .ops import"
    ),
}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the kernel export script."""
    parser = argparse.ArgumentParser(
        description=(
            "Copy Axolotl's ScatterMoE LoRA Triton kernels into a disposable "
            "HF Kernel Hub universal package."
        )
    )
    parser.add_argument(
        "--source-dir",
        type=Path,
        default=DEFAULT_SOURCE_DIR,
        help=f"ScatterMoE LoRA source package to copy. Default: {DEFAULT_SOURCE_DIR}",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help=f"Destination build/dist directory. Default: {DEFAULT_OUTPUT_DIR}",
    )
    parser.add_argument(
        "--repo-id",
        default=DEFAULT_REPO_ID,
        help=f"HF Hub repo id to write into build.toml. Default: {DEFAULT_REPO_ID}",
    )
    parser.add_argument(
        "--version",
        type=int,
        default=1,
        help="Kernel major version written to build.toml and metadata.json.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Delete the output directory first if it already exists.",
    )
    parser.add_argument(
        "--no-source-layout",
        action="store_true",
        help="Only write the shippable build/ tree, not torch-ext/ sources.",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help=(
            "Upload the generated universal kernel package with huggingface_hub. "
            "This bypasses kernel-builder and is intended for pure Python/Triton "
            "universal kernels."
        ),
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Create the HF Hub repo as private when used with --upload.",
    )
    parser.add_argument(
        "--skip-version-branch",
        action="store_true",
        help="With --upload, only upload main and skip the v branch.",
    )
    return parser.parse_args()


def should_skip_file(path: Path) -> bool:
    """Return True when *path*'s file name matches an excluded pattern."""
    return any(
        fnmatch.fnmatch(path.name, pattern) for pattern in EXCLUDED_FILE_PATTERNS
    )


def iter_source_files(source_dir: Path) -> list[Path]:
    """List copyable files under *source_dir* in a deterministic order.

    Cache directories are pruned in place and both directory and file names
    are sorted so the traversal (and hence the content hash) is stable across
    platforms and runs.
    """
    files: list[Path] = []
    for root, dirs, filenames in os.walk(source_dir):
        # Mutating dirs in place prunes os.walk's descent into cache dirs.
        dirs[:] = sorted(d for d in dirs if d not in EXCLUDED_DIRS)
        for filename in sorted(filenames):
            path = Path(root) / filename
            if not should_skip_file(path):
                files.append(path)
    return files


def content_hash(source_dir: Path) -> str:
    """Return a short, deterministic fingerprint of the package contents.

    Both relative paths and file bytes feed the digest, so renames and edits
    alike change the hash. SHA-1 is used only as a content fingerprint here,
    not for security.
    """
    digest = hashlib.sha1()
    for path in iter_source_files(source_dir):
        rel = path.relative_to(source_dir).as_posix()
        digest.update(rel.encode("utf-8"))
        digest.update(b"\0")
        digest.update(path.read_bytes())
        digest.update(b"\0")
    return digest.hexdigest()[:10]


def git_revision() -> str:
    """Return the short git HEAD revision of the repo, or "unknown"."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            cwd=REPO_ROOT,
            check=True,
            capture_output=True,
            text=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # git missing, or not a git checkout: best-effort metadata only.
        return "unknown"
    return result.stdout.strip() or "unknown"


def transform_python_source(text: str, rel_path: Path, op_namespace: str) -> str:
    """Rewrite one copied Python source file for standalone packaging.

    Applies the relative-import rewrites, stubs out the Axolotl-only code path
    in ``gemma4_experts.py``, and renames the ``scattermoe::`` torch custom op
    namespace to the content-hashed *op_namespace*.
    """
    for old, new in TEXT_REPLACEMENTS.items():
        text = text.replace(old, new)

    if rel_path.as_posix() == "gemma4_experts.py":
        # NOTE(review): leading indentation inside these replacement strings
        # was collapsed in transit; 8 spaces assumed — verify against the
        # actual gemma4_experts.py source line being replaced.
        text = text.replace(
            "        from axolotl.integrations.kernels.constants import resolve_experts_class",
            (
                "        raise RuntimeError(\n"
                '            "patch_gemma4_scattermoe is only available from the in-tree Axolotl "\n'
                '            "integration. Use register_scattermoe_experts() with the standalone "\n'
                '            "HF kernel package."\n'
                "        )"
            ),
        )

    return text.replace("scattermoe::", f"{op_namespace}::")


def copy_package(source_dir: Path, package_dir: Path, op_namespace: str) -> None:
    """Copy the kernel package into *package_dir*, transforming .py sources.

    Non-Python files are copied verbatim (with metadata); Python files go
    through transform_python_source(). A generated ``_ops.py`` module exposing
    the op namespace is written last.
    """
    for source in iter_source_files(source_dir):
        rel_path = source.relative_to(source_dir)
        destination = package_dir / rel_path
        destination.parent.mkdir(parents=True, exist_ok=True)

        if source.suffix == ".py":
            text = source.read_text(encoding="utf-8")
            text = transform_python_source(text, rel_path, op_namespace)
            destination.write_text(text, encoding="utf-8")
        else:
            shutil.copy2(source, destination)

    write_ops_module(package_dir / "_ops.py", op_namespace)


def write_ops_module(path: Path, op_namespace: str) -> None:
    """Write the generated ``_ops.py`` that binds the torch op namespace."""
    path.write_text(
        "\n".join(
            [
                "import torch",
                "",
                f"ops = torch.ops.{op_namespace}",
                "",
                "",
                "def add_op_namespace_prefix(op_name: str) -> str:",
                # Doubled braces emit literal {op_name} into the generated file.
                f'    return f"{op_namespace}::{{op_name}}"',
                "",
            ]
        ),
        encoding="utf-8",
    )


def write_build_toml(path: Path, repo_id: str, version: int) -> None:
    """Write the kernel-builder ``build.toml`` for a universal kernel."""
    lines = [
        "[general]",
        f'name = "{PACKAGE_NAME}"',
        "universal = true",
        f"version = {version}",
        "",
    ]
    if repo_id:
        lines.extend(
            [
                "[general.hub]",
                f'repo-id = "{repo_id}"',
                "",
            ]
        )
    path.write_text("\n".join(lines), encoding="utf-8")


def write_flake(path: Path) -> None:
    """Write the standard kernel-builder ``flake.nix`` for this package."""
    path.write_text(
        """{
  description = "Flake for scattermoe_lora kernel";

  inputs = {
    builder.url = "github:huggingface/kernels";
  };

  outputs =
    {
      self,
      builder,
    }:
    builder.lib.genKernelFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
""",
        encoding="utf-8",
    )


def write_readme(path: Path, repo_id: str, source_hash: str, op_namespace: str) -> None:
    """Write the Hub model-card README with export provenance metadata."""
    # Fall back to the canonical repo id when none was provided. The previous
    # fallback "/scattermoe-lora" was not a valid Hub repo id (leading slash,
    # missing namespace) and rendered a broken get_kernel() snippet.
    repo_display = repo_id or DEFAULT_REPO_ID
    path.write_text(
        f"""---
library_name: kernels
license: apache-2.0
tags:
- kernel
- kernels
---

# ScatterMoE LoRA

Standalone Hugging Face Kernel Hub package for Axolotl's ScatterMoE LoRA Triton kernels.

This package is generated from Axolotl's in-tree `scattermoe_lora` sources and is exported as a universal kernel because the implementation is Python/Triton rather than a precompiled C++/CUDA extension.

```python
from kernels import get_kernel

scattermoe_lora = get_kernel("{repo_display}")
```

Export metadata:

- source package: `src/axolotl/integrations/kernels/libs/scattermoe_lora`
- source revision: `{git_revision()}`
- source content hash: `{source_hash}`
- torch custom op namespace: `{op_namespace}`

The generated `build/torch-universal/{PACKAGE_NAME}` directory is the shippable Hub artifact. `torch-ext/{PACKAGE_NAME}` is included so `kernel-builder build-and-copy` can regenerate the universal build tree if desired.
""",
        encoding="utf-8",
    )


def write_metadata(path: Path, version: int) -> None:
    """Write the ``metadata.json`` consumed by the kernels loader."""
    path.write_text(
        json.dumps({"version": version}, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )


def prepare_output_dir(output_dir: Path, force: bool) -> None:
    """Create a fresh *output_dir*, refusing to clobber unless *force*."""
    if output_dir.exists():
        if not force:
            raise FileExistsError(
                f"{output_dir} already exists. Re-run with --force to replace it."
            )
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)


def build_package(args: argparse.Namespace) -> Path:
    """Assemble the full HF kernel project tree and return its root.

    Layout produced:
      build.toml / flake.nix / README.md   project metadata
      torch-ext/<pkg>                      editable sources (unless disabled)
      build/torch-universal/<pkg>          the shippable Hub artifact
    """
    source_dir = args.source_dir.resolve()
    output_dir = args.output_dir.resolve()

    if not source_dir.is_dir():
        raise FileNotFoundError(f"source package does not exist: {source_dir}")
    if not (source_dir / "__init__.py").is_file():
        raise FileNotFoundError(f"source package is missing __init__.py: {source_dir}")

    source_hash = content_hash(source_dir)
    # Content-addressed namespace keeps differently-built copies of the
    # kernels from colliding in torch.ops.
    op_namespace = f"_{PACKAGE_NAME}_{source_hash}"

    prepare_output_dir(output_dir, args.force)

    write_build_toml(output_dir / "build.toml", args.repo_id, args.version)
    write_flake(output_dir / "flake.nix")
    write_readme(output_dir / "README.md", args.repo_id, source_hash, op_namespace)

    if not args.no_source_layout:
        copy_package(source_dir, output_dir / "torch-ext" / PACKAGE_NAME, op_namespace)

    build_package_dir = output_dir / "build" / BUILD_VARIANT / PACKAGE_NAME
    copy_package(source_dir, build_package_dir, op_namespace)
    write_metadata(build_package_dir.parent / "metadata.json", args.version)

    return output_dir


def upload_package(args: argparse.Namespace, output_dir: Path) -> None:
    """Upload *output_dir* to the Hub as a kernel repo (main + version branch).

    Raises RuntimeError when huggingface_hub is missing or too old to support
    repo_type="kernel", and ValueError when no repo id was given.
    """
    if not args.repo_id:
        raise ValueError("--repo-id is required when using --upload")

    try:
        # Imported lazily so the build path works without huggingface_hub.
        from huggingface_hub import HfApi, constants as hf_constants
    except ImportError as exc:
        raise RuntimeError(
            "--upload requires huggingface_hub. Install it or run the upload "
            "manually with the Hugging Face CLI."
        ) from exc

    # Older hub clients reject repo_type="kernel"; probe the constants the
    # installed version exposes before attempting any network call.
    accepted_repo_types = getattr(
        hf_constants,
        "REPO_TYPES_WITH_KERNEL",
        getattr(hf_constants, "REPO_TYPES", ()),
    )
    if HF_REPO_TYPE not in accepted_repo_types:
        try:
            hub_version = metadata.version("huggingface_hub")
        except metadata.PackageNotFoundError:
            hub_version = "unknown"
        raise RuntimeError(
            "Your huggingface_hub installation does not support "
            f"repo_type={HF_REPO_TYPE!r} (found huggingface_hub {hub_version}). "
            "Upgrade with: python -m pip install --upgrade "
            "'huggingface_hub>=1.10.0'"
        )

    api = HfApi()
    repo_id = api.create_repo(
        repo_id=args.repo_id,
        repo_type=HF_REPO_TYPE,
        private=args.private,
        exist_ok=True,
    ).repo_id

    # Restrict deletions to paths this script owns so unrelated repo files
    # (e.g. a hand-edited model card section elsewhere) survive re-uploads.
    delete_patterns = [
        "build/**",
        "torch-ext/**",
        "build.toml",
        "flake.nix",
        "README.md",
    ]

    api.upload_folder(
        repo_id=repo_id,
        repo_type=HF_REPO_TYPE,
        folder_path=output_dir,
        revision="main",
        delete_patterns=delete_patterns,
        commit_message="Upload ScatterMoE LoRA universal kernel",
    )
    print(f"Uploaded main branch: https://hf.co/{repo_id}")

    if args.skip_version_branch:
        return

    version_branch = f"v{args.version}"
    api.create_branch(
        repo_id=repo_id,
        repo_type=HF_REPO_TYPE,
        branch=version_branch,
        revision="main",
        exist_ok=True,
    )
    api.upload_folder(
        repo_id=repo_id,
        repo_type=HF_REPO_TYPE,
        folder_path=output_dir,
        revision=version_branch,
        delete_patterns=delete_patterns,
        commit_message=f"Upload ScatterMoE LoRA universal kernel {version_branch}",
    )
    print(f"Uploaded version branch: https://hf.co/{repo_id}/tree/{version_branch}")


def main() -> int:
    """CLI entry point; returns a process exit status."""
    args = parse_args()
    try:
        output_dir = build_package(args)
        if args.upload:
            upload_package(args, output_dir)
    except Exception as exc:
        # Top-level CLI boundary: report and exit nonzero.
        print(f"error: {exc}", file=sys.stderr)
        return 1

    print(f"Wrote ScatterMoE LoRA HF kernel package to: {output_dir}")
    print(f"Shippable artifact: {output_dir / 'build' / BUILD_VARIANT / PACKAGE_NAME}")
    if args.upload:
        print(f'Load it with: get_kernel("{args.repo_id}", version={args.version})')
        print(f"Uploaded as Hugging Face repo_type={HF_REPO_TYPE!r}.")
        return 0

    # NOTE(review): indentation of these help strings was collapsed in
    # transit; two-space nesting reconstructed by convention.
    print("Next step:")
    print("  upload this universal Python/Triton kernel directly:")
    print(
        f"    python3 {Path(__file__).as_posix()} "
        f"--repo-id {args.repo_id} --force --upload"
    )
    if shutil.which("kernel-builder") is None:
        print("  optional: install kernel-builder for full Nix-based builds:")
        print(
            "    curl -fsSL "
            "https://raw.githubusercontent.com/huggingface/kernels/main/install.sh "
            "| bash"
        )
    else:
        print("  optional: upload with kernel-builder:")
        print(f"    cd {output_dir}")
        print("    kernel-builder build-and-upload")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())