From 05113bc91a5b8393918a618255022709e1474a6b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 22 Apr 2026 01:14:41 -0400 Subject: [PATCH] train on remote compute using Tinker compatible APIs (#3614) * train on remote compute using Tinker compatible APIs * chore: lint * fixes with latest hatchery changes * chore: lint --- src/axolotl/integrations/hatchery/__init__.py | 27 ++ src/axolotl/integrations/hatchery/args.py | 62 +++ src/axolotl/integrations/hatchery/data.py | 160 +++++++ .../hatchery/examples/prep_math_rl.py | 87 ++++ .../hatchery/examples/tinker_rl.yaml | 47 ++ .../hatchery/examples/tinker_sft.yaml | 42 ++ src/axolotl/integrations/hatchery/plugin.py | 147 +++++++ .../integrations/hatchery/rewards/__init__.py | 3 + .../hatchery/rewards/math_reward.py | 78 ++++ .../integrations/hatchery/rl_trainer.py | 409 ++++++++++++++++++ src/axolotl/integrations/hatchery/trainer.py | 327 ++++++++++++++ 11 files changed, 1389 insertions(+) create mode 100644 src/axolotl/integrations/hatchery/__init__.py create mode 100644 src/axolotl/integrations/hatchery/args.py create mode 100644 src/axolotl/integrations/hatchery/data.py create mode 100644 src/axolotl/integrations/hatchery/examples/prep_math_rl.py create mode 100644 src/axolotl/integrations/hatchery/examples/tinker_rl.yaml create mode 100644 src/axolotl/integrations/hatchery/examples/tinker_sft.yaml create mode 100644 src/axolotl/integrations/hatchery/plugin.py create mode 100644 src/axolotl/integrations/hatchery/rewards/__init__.py create mode 100644 src/axolotl/integrations/hatchery/rewards/math_reward.py create mode 100644 src/axolotl/integrations/hatchery/rl_trainer.py create mode 100644 src/axolotl/integrations/hatchery/trainer.py diff --git a/src/axolotl/integrations/hatchery/__init__.py b/src/axolotl/integrations/hatchery/__init__.py new file mode 100644 index 000000000..c0d8510db --- /dev/null +++ b/src/axolotl/integrations/hatchery/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Hatchery/Tinker remote training integration for Axolotl. + +Routes axolotl's preprocessed data to a remote training API (Tinker or +Hatchery) instead of running forward/backward locally. The remote +service handles model weights, LoRA adapters, and gradient updates. +""" + +from .args import HatcheryArgs, HatcheryConfig +from .plugin import HatcheryPlugin + +__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"] + +# Usage: +# plugins: +# - axolotl.integrations.hatchery.HatcheryPlugin +# +# hatchery: +# backend: tinker # or "hatchery" +# lora_rank: 32 +# loss_fn: cross_entropy # SFT +# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer) +# +# learning_rate: 1e-4 # top-level, not under hatchery: diff --git a/src/axolotl/integrations/hatchery/args.py b/src/axolotl/integrations/hatchery/args.py new file mode 100644 index 000000000..e3fdb95a2 --- /dev/null +++ b/src/axolotl/integrations/hatchery/args.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Pydantic config schema for the Hatchery integration.""" + +from __future__ import annotations + +from typing import Any, Literal, Optional + +from pydantic import BaseModel, Field + + +class HatcheryConfig(BaseModel): + """Nested config under `hatchery:` in the axolotl YAML. + + Only contains hatchery-specific settings. 
Standard training params + (learning_rate, weight_decay, adam_beta1/2, max_grad_norm, + gradient_accumulation_steps) are read from axolotl's top-level config. + """ + + # Backend & connection + backend: Literal["tinker", "hatchery"] = "tinker" + base_url: Optional[str] = None + api_key: Optional[str] = None + project_id: Optional[str] = None + + # LoRA config sent to remote + lora_rank: int = Field(32, ge=1, le=256) + train_attn: bool = True + train_mlp: bool = True + train_unembed: bool = True + + # Loss function + loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = ( + "cross_entropy" + ) + loss_fn_config: Optional[dict[str, Any]] = None + + # Pipelining: submit next batch before awaiting previous result + pipeline: bool = True + + # Sampling params (for RL flows) + max_sample_tokens: int = 256 + sample_temperature: float = 1.0 + num_samples: int = 4 + + # Reward functions (for RL) — list of fully qualified names + reward_funcs: Optional[list[str]] = None + + # Checkpointing + save_steps: Optional[int] = None + save_name_prefix: str = "checkpoint" + + # Timeout per future (seconds) + future_timeout: float = 600.0 + + +class HatcheryArgs(BaseModel): + """Top-level mixin that adds the nested `hatchery:` field.""" + + hatchery: Optional[HatcheryConfig] = None diff --git a/src/axolotl/integrations/hatchery/data.py b/src/axolotl/integrations/hatchery/data.py new file mode 100644 index 000000000..f7baa2cca --- /dev/null +++ b/src/axolotl/integrations/hatchery/data.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Convert axolotl batch tensors to Tinker/Hatchery Datum format. + +Both Tinker and Hatchery expect the client to apply the causal LM shift: + + Original tokens: [t0, t1, t2, ..., t_{L-1}] + model_input: [t0, t1, ..., t_{L-2}] (last token dropped) + target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped) + weights: [w1, w2, ..., w_{L-1}] (aligned to targets) + +At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}. +""" + +from __future__ import annotations + +from typing import Any + +import torch + + +def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]: + """Serialize a tensor to the TensorData wire dict.""" + flat = t.detach().cpu().flatten() + dtype_map = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.int64: "int64", + torch.int32: "int32", + } + return { + "dtype": dtype_map.get(flat.dtype, "float32"), + "shape": list(t.shape), + "data": flat.tolist(), + } + + +def _make_datum( + tokens: list[int], + loss_fn_inputs: dict[str, torch.Tensor], +) -> dict[str, Any]: + """Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery).""" + return { + "model_input": { + "chunks": [{"type": "encoded_text", "tokens": tokens}], + }, + "loss_fn_inputs": { + key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items() + }, + } + + +def datums_to_tinker(datums: list[dict[str, Any]]): + """Wrap plain-dict datums into tinker.types.Datum objects. + + Both the Tinker SDK and updated Hatchery client accept these. 
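+
+    A minimal round-trip sketch (token and tensor values are
+    illustrative only):
+
+        wire = _make_datum(
+            [1, 2, 3],
+            {"target_tokens": torch.tensor([2, 3, 4]), "weights": torch.ones(3)},
+        )
+        (datum,) = datums_to_tinker([wire])
+        # datum.loss_fn_inputs now holds tinker.types.TensorData objects
+        # rebuilt from the wire dicts above.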
+ """ + import tinker.types as tt + + result = [] + for d in datums: + tokens = d["model_input"]["chunks"][0]["tokens"] + tinker_inputs = {} + for key, wire in d["loss_fn_inputs"].items(): + tinker_inputs[key] = tt.TensorData( + data=wire["data"], + dtype=wire["dtype"], + shape=wire["shape"], + ) + result.append( + tt.Datum( + model_input=tt.ModelInput.from_ints(tokens), + loss_fn_inputs=tinker_inputs, + ) + ) + return result + + +def batch_to_datums_sft( + input_ids: torch.Tensor, + labels: torch.Tensor, + attention_mask: torch.Tensor | None = None, +) -> list[dict[str, Any]]: + """Convert an axolotl SFT batch to Datum dicts with causal shift.""" + batch_size = input_ids.size(0) + datums = [] + + for i in range(batch_size): + ids = input_ids[i] + lbl = labels[i] + + if attention_mask is not None: + seq_len = int(attention_mask[i].sum().item()) + ids = ids[:seq_len] + lbl = lbl[:seq_len] + + model_tokens = ids[:-1].tolist() + shifted_labels = lbl[1:] + + target_tokens = shifted_labels.clone() + weights = (shifted_labels != -100).float() + target_tokens[target_tokens == -100] = 0 + + datums.append( + _make_datum( + model_tokens, + { + "target_tokens": target_tokens, + "weights": weights, + }, + ) + ) + + return datums + + +def batch_to_datums_rl( + input_ids: torch.Tensor, + labels: torch.Tensor, + logprobs: torch.Tensor, + advantages: torch.Tensor, + attention_mask: torch.Tensor | None = None, +) -> list[dict[str, Any]]: + """Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift.""" + batch_size = input_ids.size(0) + datums = [] + + for i in range(batch_size): + ids = input_ids[i] + lbl = labels[i] + + if attention_mask is not None: + seq_len = int(attention_mask[i].sum().item()) + else: + seq_len = ids.size(0) + ids = ids[:seq_len] + lbl = lbl[:seq_len] + lp = logprobs[i, :seq_len] + adv = advantages[i, :seq_len] + + model_tokens = ids[:-1].tolist() + + target_tokens = lbl[1:].clone() + target_tokens[target_tokens == -100] = 0 + + datums.append( + _make_datum( + model_tokens, + { + "target_tokens": target_tokens, + "logprobs": lp[1:], + "advantages": adv[1:], + }, + ) + ) + + return datums diff --git a/src/axolotl/integrations/hatchery/examples/prep_math_rl.py b/src/axolotl/integrations/hatchery/examples/prep_math_rl.py new file mode 100644 index 000000000..183815907 --- /dev/null +++ b/src/axolotl/integrations/hatchery/examples/prep_math_rl.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Prepare hendrycks_math for RL training with Hatchery/Tinker. + +Creates a dataset with chat-formatted prompts that include +a hidden gold answer tag for the reward function. 
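+
+The tag format the reward function looks for is:
+
+    <problem text><|gold|>ANSWER<|/gold|>
+
+where ANSWER is the \boxed{} content extracted from the reference solution.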
+ +Run: + python src/axolotl/integrations/hatchery/examples/prep_math_rl.py +""" + +import os +import re + +from datasets import Dataset, load_dataset +from transformers import AutoTokenizer + + +def extract_boxed(text: str) -> str: + match = re.search(r"\\boxed\{", text) + if not match: + return "" + start = match.end() + depth = 1 + i = start + while i < len(text) and depth > 0: + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + i += 1 + return text[start : i - 1] if depth == 0 else "" + + +def main(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True) + + ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test") + level = os.environ.get("MATH_LEVEL", "Level 1") + filtered_rows = [x for x in ds if x["level"] == level] + print(f"{level} algebra: {len(filtered_rows)} problems") + + rows = [] + for prob in filtered_rows: + gold = extract_boxed(prob["solution"]) + if not gold: + continue + + # Format as chat prompt with hidden gold tag + prompt = ( + f"Solve the following math problem. " + f"Show your work and put your final answer in \\boxed{{}}.\n\n" + f"{prob['problem']}" + f"<|gold|>{gold}<|/gold|>" + ) + + # Tokenize the prompt + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) + prompt_ids = tokenizer.encode(text, add_special_tokens=False) + + rows.append( + { + "input_ids": prompt_ids, + "labels": [-100] * len(prompt_ids), + "attention_mask": [1] * len(prompt_ids), + } + ) + + out = Dataset.from_list(rows) + out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}" + out.save_to_disk(out_dir) + print(f"Saved {len(out)} examples to {out_dir}") + if rows: + print( + f"Prompt length range: {min(len(r['input_ids']) for r in rows)}" + f"-{max(len(r['input_ids']) for r in rows)}" + ) + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/integrations/hatchery/examples/tinker_rl.yaml b/src/axolotl/integrations/hatchery/examples/tinker_rl.yaml new file mode 100644 index 000000000..caab3fe0d --- /dev/null +++ b/src/axolotl/integrations/hatchery/examples/tinker_rl.yaml @@ -0,0 +1,47 @@ +# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B +# +# Prep: +# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py +# +# Run: +# export TINKER_API_KEY="your-key" +# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml + +base_model: Qwen/Qwen3-8B + +plugins: + - axolotl.integrations.hatchery.HatcheryPlugin + +hatchery: + backend: tinker + lora_rank: 16 + loss_fn: importance_sampling + max_sample_tokens: 2048 + sample_temperature: 0.7 + num_samples: 4 + pipeline: true + save_steps: 5 + reward_funcs: + - axolotl.integrations.hatchery.rewards.math_reward.math_reward + +datasets: + - path: ./data/math_rl_level1 + ds_type: arrow + type: completion + +sequence_len: 2048 + +learning_rate: 5.0e-5 +optimizer: adamw_torch +adam_beta1: 0.9 +adam_beta2: 0.95 +weight_decay: 0.01 +max_grad_norm: 1.0 + +max_steps: 10 +num_epochs: 1 +micro_batch_size: 1 +gradient_accumulation_steps: 1 +logging_steps: 1 + +output_dir: ./outputs/tinker-rl-math diff --git a/src/axolotl/integrations/hatchery/examples/tinker_sft.yaml b/src/axolotl/integrations/hatchery/examples/tinker_sft.yaml new file mode 100644 index 000000000..d99f043ae --- /dev/null +++ b/src/axolotl/integrations/hatchery/examples/tinker_sft.yaml @@ -0,0 +1,42 @@ +# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B +# +# Usage: +# export 
TINKER_API_KEY="your-key" +# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml + +base_model: Qwen/Qwen3-8B + +plugins: + - axolotl.integrations.hatchery.HatcheryPlugin + +hatchery: + backend: tinker + lora_rank: 16 + loss_fn: cross_entropy + pipeline: true + save_steps: 10 + +datasets: + - path: TeichAI/kimi-k2-thinking-1000x + split: train[:50] + type: chat_template + chat_template: qwen3 + split_thinking: true + +chat_template: qwen3 +sequence_len: 2048 + +learning_rate: 3.0e-4 +optimizer: adamw_torch +adam_beta1: 0.9 +adam_beta2: 0.95 +weight_decay: 0.01 +max_grad_norm: 1.0 + +num_epochs: 1 +max_steps: 20 +micro_batch_size: 2 +gradient_accumulation_steps: 1 +logging_steps: 1 + +output_dir: ./outputs/tinker-sft diff --git a/src/axolotl/integrations/hatchery/plugin.py b/src/axolotl/integrations/hatchery/plugin.py new file mode 100644 index 000000000..1546958e8 --- /dev/null +++ b/src/axolotl/integrations/hatchery/plugin.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Axolotl plugin that routes training to a remote Hatchery/Tinker API.""" + +from __future__ import annotations + +import torch +from peft import PeftModel +from transformers import AutoConfig, PreTrainedModel, Trainer + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class HatcheryPlugin(BasePlugin): + """Plugin that replaces local training with remote API calls. + + Activated by adding to the axolotl YAML: + + plugins: + - axolotl.integrations.hatchery.HatcheryPlugin + + hatchery: + backend: tinker # or "hatchery" + lora_rank: 32 + loss_fn: cross_entropy + # ... see HatcheryConfig for full options + """ + + def get_input_args(self) -> str: + return "axolotl.integrations.hatchery.args.HatcheryArgs" + + def register(self, cfg: dict): + """Auto-set config values needed for remote training.""" + if cfg.get("remove_unused_columns") is None: + cfg["remove_unused_columns"] = False + + def pre_model_load(self, cfg: DictDefault): + """Replace model loading with a tiny stub.""" + hcfg = cfg.hatchery or {} + backend = ( + hcfg.get("backend", "tinker") + if isinstance(hcfg, dict) + else getattr(hcfg, "backend", "tinker") + ) + LOG.info( + f"Hatchery plugin active: training dispatched to remote " + f"{backend} API. Skipping local model weight loading." 
+ ) + + from axolotl.loaders import ModelLoader + + def _stub_build_model(loader_self) -> bool: + base_model = loader_self.cfg.base_model + LOG.info(f"Skipping model weight loading for: {base_model}") + + config = AutoConfig.from_pretrained( + base_model, + trust_remote_code=loader_self.cfg.get("trust_remote_code", False), + ) + + class _Stub(PreTrainedModel): + config_class = type(config) + _no_split_modules: list[str] = [] + supports_gradient_checkpointing = False + + def __init__(self, cfg): + super().__init__(cfg) + vocab_size = getattr(cfg, "vocab_size", 32000) + self.embed_tokens = torch.nn.Embedding(vocab_size, 1) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + pass + + def get_output_embeddings(self): + return None + + loader_self.model = _Stub(config) + return True + + ModelLoader._build_model = _stub_build_model # type: ignore[method-assign,assignment] + + def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: + """Return the appropriate remote trainer class.""" + hcfg = cfg.hatchery + loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy" + + if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"): + from .rl_trainer import HatcheryRLTrainer + + return HatcheryRLTrainer + + from .trainer import HatcheryTrainer + + return HatcheryTrainer + + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + model._hatchery_remote = True + + def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + LOG.info( + "Hatchery: skipping local model save (weights are on remote API). " + "Use `tinker checkpoint download` or hatchery CLI to retrieve." + ) + + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + """Inject hatchery config + axolotl training params into the trainer.""" + from .args import HatcheryConfig + from .rl_trainer import HatcheryRLTrainer + from .trainer import HatcheryTrainer + + if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)): + return + + hcfg = cfg.hatchery + if isinstance(hcfg, dict): + hatchery_config = HatcheryConfig(**hcfg) + elif hcfg is None: + hatchery_config = HatcheryConfig() + else: + hatchery_config = hcfg + + trainer.hatchery_args = hatchery_config + trainer._base_model_name = cfg.base_model + + # Pull standard training params from axolotl config so they + # don't need to be duplicated under hatchery: + trainer._optim_params = { + "learning_rate": cfg.learning_rate + if cfg.learning_rate is not None + else 1e-4, + "beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9, + "beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95, + "eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12, + "weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0, + "grad_clip_norm": cfg.max_grad_norm + if cfg.max_grad_norm is not None + else 0.0, + } diff --git a/src/axolotl/integrations/hatchery/rewards/__init__.py b/src/axolotl/integrations/hatchery/rewards/__init__.py new file mode 100644 index 000000000..1cfe76e77 --- /dev/null +++ b/src/axolotl/integrations/hatchery/rewards/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 diff --git a/src/axolotl/integrations/hatchery/rewards/math_reward.py b/src/axolotl/integrations/hatchery/rewards/math_reward.py new file mode 100644 index 000000000..970353b0b --- /dev/null +++ 
b/src/axolotl/integrations/hatchery/rewards/math_reward.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Math reward function for hendrycks_math GRPO training. + +Uses math_verify for robust answer comparison. Falls back to +exact string match of \\boxed{} content only when math_verify +is unavailable. +""" + +from __future__ import annotations + +import logging +import re + +LOG = logging.getLogger(__name__) + + +def extract_boxed(text: str) -> str | None: + """Extract \\boxed{...} answer handling nested braces.""" + match = re.search(r"\\boxed\{", text) + if not match: + return None + start = match.end() + depth = 1 + i = start + while i < len(text) and depth > 0: + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + i += 1 + return text[start : i - 1] if depth == 0 else None + + +def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]: + """Score completions by checking if \\boxed{} answer matches the gold answer. + + The gold answer is extracted from the prompt (appended as a hidden + tag by the dataset preprocessing). Format: + ... <|gold|>ANSWER<|/gold|> + """ + rewards = [] + for prompt, completion in zip(prompts, completions, strict=True): + gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt) + if not gold_match: + rewards.append(0.0) + continue + + gold_answer = gold_match.group(1).strip() + pred_answer = extract_boxed(completion) + + if pred_answer is None: + rewards.append(0.0) + continue + + verified = None + try: + from math_verify import parse, verify + + gold_parsed = parse(gold_answer) + pred_parsed = parse(pred_answer) + verified = verify(gold_parsed, pred_parsed) + except Exception: + LOG.debug( + "math_verify unavailable or failed, using string fallback", + exc_info=True, + ) + + if verified is not None: + rewards.append(1.0 if verified else 0.0) + elif pred_answer.strip() == gold_answer.strip(): + rewards.append(1.0) + else: + rewards.append(0.0) + + return rewards diff --git a/src/axolotl/integrations/hatchery/rl_trainer.py b/src/axolotl/integrations/hatchery/rl_trainer.py new file mode 100644 index 000000000..fc6d32d6a --- /dev/null +++ b/src/axolotl/integrations/hatchery/rl_trainer.py @@ -0,0 +1,409 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API. + +Full RL loop per step: + 1. Extract prompts from dataset batch + 2. Sample N completions per prompt via remote SamplingClient + 3. Score completions with local reward functions + 4. Compute GRPO-style advantages (per-group normalization) + 5. Send (prompt+completion, logprobs, advantages) as forward_backward + 6. 
Optimizer step +""" + +from __future__ import annotations + +import importlib +import inspect +import re +import time +from typing import Any, Callable, Optional + +import torch +from transformers.trainer_utils import TrainOutput + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.utils.logging import get_logger + +from .args import HatcheryConfig +from .data import batch_to_datums_rl, datums_to_tinker +from .trainer import _create_training_client + +LOG = get_logger(__name__) + + +def _load_reward_func(fqn: str) -> Callable: + """Load a reward function from a fully qualified name like 'module.func'.""" + module_path = ".".join(fqn.split(".")[:-1]) + func_name = fqn.split(".")[-1] + mod = importlib.import_module(module_path) + func = getattr(mod, func_name) + if len(inspect.signature(func).parameters) < 2: + raise ValueError(f"Reward function {fqn} must accept (prompts, completions)") + return func + + +class HatcheryRLTrainer(AxolotlTrainer): + """Remote RL trainer using Tinker/Hatchery for sampling and training.""" + + hatchery_args: Optional[HatcheryConfig] + _base_model_name: Optional[str] + _training_client: Any + _reward_functions: list[Callable] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hatchery_args = None + self._base_model_name = None + self._training_client = None + self._reward_functions = [] + + def _ensure_reward_functions(self): + if self._reward_functions: + return + args = self.hatchery_args + if not args or not args.reward_funcs: + raise ValueError( + "No reward functions configured. Set hatchery.reward_funcs " + "in YAML, e.g. reward_funcs: ['my_module.my_reward']" + ) + for fqn in args.reward_funcs: + self._reward_functions.append(_load_reward_func(fqn)) + LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)") + + def _get_training_client(self): + if self._training_client is not None: + return self._training_client + + self._training_client = _create_training_client( + self.hatchery_args, self._base_model_name + ) + LOG.info( + f"Remote RL session created: backend={self.hatchery_args.backend}, " + f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}" + ) + return self._training_client + + def _sample_completions(self, prompt_ids_list: list[list[int]]): + """Sample completions for prompts via remote API.""" + import tinker.types as tt + + tc = self._get_training_client() + args = self.hatchery_args + assert args is not None # validated by _get_training_client + results = [] + + sc = tc.save_weights_and_get_sampling_client() + + for prompt_ids in prompt_ids_list: + if hasattr(sc, "sampling_session_id"): + sample_result = sc.sample( + prompt_ids, + max_tokens=args.max_sample_tokens, + temperature=args.sample_temperature, + n=args.num_samples, + ).result(timeout=args.future_timeout) + else: + mi = tt.ModelInput.from_ints(prompt_ids) + sp = tt.SamplingParams( + max_tokens=args.max_sample_tokens, + temperature=args.sample_temperature, + top_p=0.95, + top_k=-1, + ) + sample_result = sc.sample( + prompt=mi, + num_samples=args.num_samples, + sampling_params=sp, + ).result(timeout=args.future_timeout) + + sequences = ( + sample_result.sequences + if hasattr(sample_result, "sequences") + else sample_result.get("sequences", []) + ) + for seq in sequences: + tokens = ( + list(seq.tokens) + if hasattr(seq, "tokens") + else seq.get("tokens", []) + ) + logprobs = ( + list(seq.logprobs) + if hasattr(seq, "logprobs") and seq.logprobs + else seq.get("logprobs", []) + ) + results.append( + { + 
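+                        # "tokens" holds the full sequence (prompt +
+                        # completion); prompt_len lets the training loop
+                        # mask prompt positions out of the loss and align
+                        # logprobs/advantages to completion tokens.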
"tokens": list(prompt_ids) + tokens, + "completion_tokens": tokens, + "logprobs": logprobs, + "prompt_len": len(prompt_ids), + } + ) + + return results + + def _compute_rewards( + self, prompts: list[str], completions: list[str] + ) -> list[float]: + total_rewards = [0.0] * len(completions) + for reward_fn in self._reward_functions: + rewards = reward_fn(prompts, completions) + for i, r in enumerate(rewards): + total_rewards[i] += r + return total_rewards + + @staticmethod + def _compute_advantages(rewards: list[float], group_size: int) -> list[float]: + advantages = [] + for i in range(0, len(rewards), group_size): + group = rewards[i : i + group_size] + mean = sum(group) / len(group) + var = sum((r - mean) ** 2 for r in group) / max(len(group), 1) + std = var**0.5 if var > 1e-8 else 1.0 + advantages.extend([(r - mean) / std for r in group]) + return advantages + + def _do_optim_step(self): + import tinker.types as tt + + tc = self._get_training_client() + return tc.optim_step(tt.AdamParams(**self._optim_params)) + + def train( + self, + resume_from_checkpoint: Optional[str] = None, + trial: Any = None, + ignore_keys_for_eval: Optional[list[str]] = None, + **kwargs, + ) -> TrainOutput: + args = self.hatchery_args + if args is None: + raise RuntimeError("hatchery_args not configured") + + self._ensure_reward_functions() + + train_dataloader = self.get_train_dataloader() + num_train_epochs = int(self.args.num_train_epochs) + max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000 + + LOG.info( + f"Remote RL training: max_steps={max_steps}, " + f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}" + ) + + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = True + self.state.is_world_process_zero = True + + self.control = self.callback_handler.on_train_begin( + self.args, + self.state, + self.control, # type: ignore[has-type] + ) + + tokenizer = self.processing_class + global_step = 0 + total_loss = 0.0 + total_reward = 0.0 + start_time = time.time() + + for _epoch in range(num_train_epochs): + if global_step >= max_steps: + break + + for batch in train_dataloader: + if global_step >= max_steps: + break + + self.control = self.callback_handler.on_step_begin( + self.args, self.state, self.control + ) + + prompt_ids_batch = batch["input_ids"] + # Full prompt text (with gold tag) for reward scoring + prompt_texts = tokenizer.batch_decode( + prompt_ids_batch, skip_special_tokens=False + ) + + # Strip <|gold|>...<|/gold|> from token ids before + # sending to the model for sampling — the gold answer + # must only be visible to the local reward function. + sampling_prompts = [] + for prompt_text in prompt_texts: + clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text) + clean_ids = tokenizer.encode(clean, add_special_tokens=False) + sampling_prompts.append(clean_ids) + + # 1. Sample completions (without gold answer) + t0 = time.time() + samples = self._sample_completions(sampling_prompts) + t_sample = time.time() - t0 + + if not samples: + LOG.warning("No samples generated, skipping step") + continue + LOG.info( + f"Sampled {len(samples)} completions, " + f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok" + ) + + # 2. 
Decode and score + completion_texts = [ + tokenizer.decode(s["completion_tokens"], skip_special_tokens=False) + for s in samples + ] + sample_prompts = [] + for prompt_text in prompt_texts: + sample_prompts.extend([prompt_text] * args.num_samples) + + rewards = self._compute_rewards(sample_prompts, completion_texts) + + # 3. GRPO advantages + advantages_list = self._compute_advantages( + rewards, group_size=args.num_samples + ) + + # 4. Build training data + all_datums = [] + for i, sample in enumerate(samples): + full_tokens = sample["tokens"] + prompt_len = sample["prompt_len"] + seq_len = len(full_tokens) + + input_ids = torch.tensor([full_tokens], dtype=torch.long) + labels = torch.full((1, seq_len), -100, dtype=torch.long) + labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:]) + + logprobs_t = torch.zeros(1, seq_len) + if sample["logprobs"]: + lp = sample["logprobs"][: seq_len - prompt_len] + logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor( + lp + ) + + adv_t = torch.zeros(1, seq_len) + adv_t[0, prompt_len:] = advantages_list[i] + + all_datums.extend( + batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t) + ) + + # 5. Forward backward (one datum at a time for memory) + optim + t0 = time.time() + tc = self._get_training_client() + step_loss = 0.0 + for datum in all_datums: + fb_future = tc.forward_backward( + datums_to_tinker([datum]), + loss_fn=args.loss_fn, + loss_fn_config=args.loss_fn_config, + ) + fb_result = fb_future.result(timeout=args.future_timeout) + if hasattr(fb_result, "metrics"): + step_loss += float( + (fb_result.metrics or {}).get("loss:sum", 0.0) + ) + elif isinstance(fb_result, dict): + step_loss += float( + fb_result.get("metrics", {}).get("loss:sum", 0.0) + ) + optim_future = self._do_optim_step() + if not args.pipeline: + optim_future.result(timeout=args.future_timeout) + t_train = time.time() - t0 + + mean_reward = sum(rewards) / len(rewards) + accuracy = sum(1 for r in rewards if r > 0) / len(rewards) + mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list) + global_step += 1 + total_loss += step_loss + total_reward += mean_reward + self.state.global_step = global_step + + log_interval = self.args.logging_steps or 1 + if global_step % log_interval == 0: + elapsed = time.time() - start_time + LOG.info( + f"[step {global_step}/{max_steps}] " + f"acc={accuracy:.2f} reward={mean_reward:.3f} " + f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} " + f"sample={t_sample:.1f}s train={t_train:.1f}s " + f"{elapsed / global_step:.1f}s/step" + ) + self.log( + { + "loss": step_loss, + "reward": mean_reward, + "accuracy": accuracy, + "mean_abs_advantage": mean_adv, + "learning_rate": self._optim_params["learning_rate"], + } + ) + + if args.save_steps and global_step % args.save_steps == 0: + self._save_remote_checkpoint(global_step) + + self.control = self.callback_handler.on_step_end( + self.args, self.state, self.control + ) + if self.control.should_training_stop: + break + + if self.control.should_training_stop: + break + + if global_step > 0: + self._save_remote_checkpoint(global_step, name="final") + + elapsed = time.time() - start_time + avg_loss = total_loss / max(global_step, 1) + avg_reward = total_reward / max(global_step, 1) + + LOG.info( + f"RL training complete: {global_step} steps, {elapsed:.1f}s, " + f"avg_reward={avg_reward:.4f}" + ) + + self.control = self.callback_handler.on_train_end( + self.args, self.state, self.control + ) + + return TrainOutput( + global_step=global_step, + training_loss=avg_loss, + metrics={ 
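+                # train_loss/train_runtime follow HF's metric naming;
+                # train_reward is this trainer's addition.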
+ "train_loss": avg_loss, + "train_reward": avg_reward, + "train_runtime": elapsed, + }, + ) + + def _save_remote_checkpoint(self, step: int, name: Optional[str] = None): + tc = self._get_training_client() + args = self.hatchery_args + assert args is not None # validated by _get_training_client + ckpt_name = name or f"{args.save_name_prefix}-{step:06d}" + try: + future = tc.save_state(ckpt_name) + future.result(timeout=args.future_timeout) + LOG.info(f"Remote checkpoint saved: {ckpt_name}") + except Exception: + LOG.exception(f"Failed to save checkpoint {ckpt_name}") + if name == "final": + raise + + def save_model(self, output_dir=None, _internal_call=False): + self._save_remote_checkpoint( + step=self.state.global_step, + name=output_dir or "hf-save", + ) + + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + raise NotImplementedError( + "HatcheryRLTrainer uses remote API; compute_loss not called locally." + ) diff --git a/src/axolotl/integrations/hatchery/trainer.py b/src/axolotl/integrations/hatchery/trainer.py new file mode 100644 index 000000000..4eb632db3 --- /dev/null +++ b/src/axolotl/integrations/hatchery/trainer.py @@ -0,0 +1,327 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +"""Remote trainer that dispatches to Tinker or Hatchery API.""" + +from __future__ import annotations + +import os +import time +from typing import Any, Optional + +import torch +from transformers.trainer_utils import TrainOutput + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.utils.logging import get_logger + +from .args import HatcheryConfig +from .data import batch_to_datums_sft, datums_to_tinker + +LOG = get_logger(__name__) + + +def _extract_loss(result) -> float: + """Extract loss:sum from a forward_backward result. + + Tinker's cross_entropy (and other losses) return the SUM of per-token + losses, not the mean. This is by design — it lets users control + normalization via the weights tensor. The trainer logs this raw sum; + users who want per-token loss should divide by number of active tokens. + """ + if hasattr(result, "metrics"): + metrics = result.metrics or {} + return float(metrics.get("loss:sum", metrics.get("loss", 0.0))) + if isinstance(result, dict): + metrics = result.get("metrics", {}) + return float(metrics.get("loss:sum", metrics.get("loss", 0.0))) + return 0.0 + + +def _create_training_client(args: HatcheryConfig, base_model: str): + """Create a training client for either Tinker or Hatchery backend.""" + if args.backend == "tinker": + import tinker + + api_key = args.api_key or os.environ.get("TINKER_API_KEY") + if not api_key: + raise ValueError( + "Tinker API key required. Set `hatchery.api_key` in config " + "or TINKER_API_KEY env var." 
+ ) + os.environ["TINKER_API_KEY"] = api_key + + service = tinker.ServiceClient(project_id=args.project_id) + return service.create_lora_training_client( + base_model=base_model, + rank=args.lora_rank, + train_mlp=args.train_mlp, + train_attn=args.train_attn, + train_unembed=args.train_unembed, + ) + + from hatchery.core.client import HatcheryClient + + base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420") + token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev") + + client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout) + return client.create_lora_training_client( + base_model=base_model, + rank=args.lora_rank, + train_attn=args.train_attn, + train_mlp=args.train_mlp, + train_unembed=args.train_unembed, + ) + + +class HatcheryTrainer(AxolotlTrainer): + """Trainer that sends preprocessed batches to a remote training API. + + Replaces local forward/backward with remote API calls to Tinker or + Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization, + chat templates, packing, etc.) but offloads compute to remote GPUs. + """ + + hatchery_args: Optional[HatcheryConfig] + _base_model_name: Optional[str] + _training_client: Any + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hatchery_args = None + self._base_model_name = None + self._training_client = None + + def _get_training_client(self): + """Lazily create the remote training session.""" + if self._training_client is not None: + return self._training_client + + args = self.hatchery_args + if args is None: + raise RuntimeError( + "HatcheryTrainer.hatchery_args not set. " + "Ensure the HatcheryPlugin is registered." + ) + + base_model = self._base_model_name + if not base_model: + raise RuntimeError("HatcheryTrainer._base_model_name not set.") + + self._training_client = _create_training_client(args, base_model) + + LOG.info( + f"Remote training session created: backend={args.backend}, " + f"model={base_model}, rank={args.lora_rank}" + ) + return self._training_client + + def _send_batch(self, batch: dict[str, torch.Tensor]): + """Convert batch to datums and send forward_backward to remote. + + Returns (future, n_active_tokens) where n_active_tokens counts + the completion tokens in this batch (for loss normalization). 
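+
+        Example: labels [[-100, -100, 5, 7]] shift to targets
+        [-100, 5, 7], so n_active_tokens == 2 and the per-token loss
+        logged by the training loop is loss:sum / 2.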
+ """ + input_ids = batch["input_ids"] + labels = batch["labels"] + attention_mask = batch.get("attention_mask") + + n_active = int((labels[:, 1:] != -100).sum().item()) + datums = batch_to_datums_sft(input_ids, labels, attention_mask) + + tc = self._get_training_client() + args = self.hatchery_args + assert args is not None # validated by _get_training_client + send_datums = datums_to_tinker(datums) + + future = tc.forward_backward( + send_datums, + loss_fn=args.loss_fn, + loss_fn_config=args.loss_fn_config, + ) + return future, n_active + + def _do_optim_step(self): + """Send optimizer step to remote using axolotl's training params.""" + import tinker.types as tt + + tc = self._get_training_client() + return tc.optim_step(tt.AdamParams(**self._optim_params)) + + def train( + self, + resume_from_checkpoint: Optional[str] = None, + trial: Any = None, + ignore_keys_for_eval: Optional[list[str]] = None, + **kwargs, + ) -> TrainOutput: + """Main training loop — sends batches to remote API.""" + args = self.hatchery_args + if args is None: + raise RuntimeError("hatchery_args not configured") + + train_dataloader = self.get_train_dataloader() + num_batches = len(train_dataloader) + + grad_accum = self.args.gradient_accumulation_steps + num_train_epochs = int(self.args.num_train_epochs) + steps_per_epoch = max(num_batches // grad_accum, 1) + max_steps = ( + self.args.max_steps + if self.args.max_steps > 0 + else steps_per_epoch * num_train_epochs + ) + + LOG.info( + f"Remote training: {num_batches} batches/epoch, " + f"{grad_accum} grad_accum, {max_steps} max steps, " + f"{num_train_epochs} epochs" + ) + + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = True + self.state.is_world_process_zero = True + + self.control = self.callback_handler.on_train_begin( + self.args, + self.state, + self.control, # type: ignore[has-type] + ) + + global_step = 0 + total_loss = 0.0 + start_time = time.time() + + for _epoch in range(num_train_epochs): + if global_step >= max_steps: + break + + self.control = self.callback_handler.on_epoch_begin( + self.args, self.state, self.control + ) + + pending_fb_futures = [] + accum_count = 0 + + for batch_idx, batch in enumerate(train_dataloader): + if global_step >= max_steps: + break + + self.control = self.callback_handler.on_step_begin( + self.args, self.state, self.control + ) + + fb_future, n_active = self._send_batch(batch) + pending_fb_futures.append((fb_future, n_active)) + accum_count += 1 + + if accum_count >= grad_accum: + step_loss_sum = 0.0 + step_active = 0 + for fut, n_act in pending_fb_futures: + result = fut.result(timeout=args.future_timeout) + step_loss_sum += _extract_loss(result) + step_active += n_act + + optim_future = self._do_optim_step() + if not args.pipeline: + optim_future.result(timeout=args.future_timeout) + + step_loss = ( + step_loss_sum / step_active + if step_active > 0 + else step_loss_sum + ) + + global_step += 1 + total_loss += step_loss + self.state.global_step = global_step + self.state.epoch = _epoch + (batch_idx + 1) / num_batches + + log_interval = self.args.logging_steps or 1 + if global_step % log_interval == 0: + elapsed = time.time() - start_time + avg_loss = total_loss / global_step + LOG.info( + f"[step {global_step}/{max_steps}] " + f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} " + f"active={step_active} " + f"{elapsed / global_step:.2f}s/step" + ) + self.log( + { + "loss": step_loss, + "learning_rate": self._optim_params["learning_rate"], + "epoch": 
self.state.epoch, + } + ) + + if args.save_steps and global_step % args.save_steps == 0: + self._save_remote_checkpoint(global_step) + + self.control = self.callback_handler.on_step_end( + self.args, self.state, self.control + ) + + pending_fb_futures = [] + accum_count = 0 + + if self.control.should_training_stop: + break + + self.control = self.callback_handler.on_epoch_end( + self.args, self.state, self.control + ) + if self.control.should_training_stop: + break + + if global_step > 0: + self._save_remote_checkpoint(global_step, name="final") + + elapsed = time.time() - start_time + avg_loss = total_loss / max(global_step, 1) + + LOG.info( + f"Training complete: {global_step} steps, {elapsed:.1f}s total, " + f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}" + ) + + self.control = self.callback_handler.on_train_end( + self.args, self.state, self.control + ) + + return TrainOutput( + global_step=global_step, + training_loss=avg_loss, + metrics={"train_loss": avg_loss, "train_runtime": elapsed}, + ) + + def _save_remote_checkpoint(self, step: int, name: Optional[str] = None): + """Save a checkpoint on the remote service.""" + tc = self._get_training_client() + args = self.hatchery_args + assert args is not None # validated by _get_training_client + ckpt_name = name or f"{args.save_name_prefix}-{step:06d}" + try: + future = tc.save_state(ckpt_name) + future.result(timeout=args.future_timeout) + LOG.info(f"Remote checkpoint saved: {ckpt_name}") + except Exception: + LOG.exception(f"Failed to save checkpoint {ckpt_name}") + if name == "final": + raise + + def save_model(self, output_dir=None, _internal_call=False): + """Delegate to remote checkpoint save so HF callbacks create checkpoints.""" + self._save_remote_checkpoint( + step=self.state.global_step, + name=output_dir or "hf-save", + ) + + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + raise NotImplementedError( + "HatcheryTrainer uses remote API; compute_loss should not be called." + )