train on remote compute using Tinker compatible APIs (#3614)

* train on remote compute using Tinker compatible APIs

* chore: lint

* fixes with latest hatchery changes

* chore: lint
This commit is contained in:
Wing Lian
2026-04-22 01:14:41 -04:00
committed by GitHub
parent e562e149ce
commit 05113bc91a
11 changed files with 1389 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Hatchery/Tinker remote training integration for Axolotl.
Routes axolotl's preprocessed data to a remote training API (Tinker or
Hatchery) instead of running forward/backward locally. The remote
service handles model weights, LoRA adapters, and gradient updates.
"""
from .args import HatcheryArgs, HatcheryConfig
from .plugin import HatcheryPlugin
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]
# Usage:
# plugins:
# - axolotl.integrations.hatchery.HatcheryPlugin
#
# hatchery:
# backend: tinker # or "hatchery"
# lora_rank: 32
# loss_fn: cross_entropy # SFT
# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer)
#
# learning_rate: 1e-4 # top-level, not under hatchery:

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Pydantic config schema for the Hatchery integration."""
from __future__ import annotations
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
class HatcheryConfig(BaseModel):
"""Nested config under `hatchery:` in the axolotl YAML.
Only contains hatchery-specific settings. Standard training params
(learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
gradient_accumulation_steps) are read from axolotl's top-level config.
"""
# Backend & connection
backend: Literal["tinker", "hatchery"] = "tinker"
base_url: Optional[str] = None
api_key: Optional[str] = None
project_id: Optional[str] = None
# LoRA config sent to remote
lora_rank: int = Field(32, ge=1, le=256)
train_attn: bool = True
train_mlp: bool = True
train_unembed: bool = True
# Loss function
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
"cross_entropy"
)
loss_fn_config: Optional[dict[str, Any]] = None
# Pipelining: submit next batch before awaiting previous result
pipeline: bool = True
# Sampling params (for RL flows)
max_sample_tokens: int = 256
sample_temperature: float = 1.0
num_samples: int = 4
# Reward functions (for RL) — list of fully qualified names
reward_funcs: Optional[list[str]] = None
# Checkpointing
save_steps: Optional[int] = None
save_name_prefix: str = "checkpoint"
# Timeout per future (seconds)
future_timeout: float = 600.0
class HatcheryArgs(BaseModel):
"""Top-level mixin that adds the nested `hatchery:` field."""
hatchery: Optional[HatcheryConfig] = None

View File

@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
Both Tinker and Hatchery expect the client to apply the causal LM shift:
Original tokens: [t0, t1, t2, ..., t_{L-1}]
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
"""
from __future__ import annotations
from typing import Any
import torch
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
"""Serialize a tensor to the TensorData wire dict."""
flat = t.detach().cpu().flatten()
dtype_map = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
}
return {
"dtype": dtype_map.get(flat.dtype, "float32"),
"shape": list(t.shape),
"data": flat.tolist(),
}
def _make_datum(
tokens: list[int],
loss_fn_inputs: dict[str, torch.Tensor],
) -> dict[str, Any]:
"""Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
return {
"model_input": {
"chunks": [{"type": "encoded_text", "tokens": tokens}],
},
"loss_fn_inputs": {
key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items()
},
}
def datums_to_tinker(datums: list[dict[str, Any]]):
"""Wrap plain-dict datums into tinker.types.Datum objects.
Both the Tinker SDK and updated Hatchery client accept these.
"""
import tinker.types as tt
result = []
for d in datums:
tokens = d["model_input"]["chunks"][0]["tokens"]
tinker_inputs = {}
for key, wire in d["loss_fn_inputs"].items():
tinker_inputs[key] = tt.TensorData(
data=wire["data"],
dtype=wire["dtype"],
shape=wire["shape"],
)
result.append(
tt.Datum(
model_input=tt.ModelInput.from_ints(tokens),
loss_fn_inputs=tinker_inputs,
)
)
return result
def batch_to_datums_sft(
input_ids: torch.Tensor,
labels: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
"""Convert an axolotl SFT batch to Datum dicts with causal shift."""
batch_size = input_ids.size(0)
datums = []
for i in range(batch_size):
ids = input_ids[i]
lbl = labels[i]
if attention_mask is not None:
seq_len = int(attention_mask[i].sum().item())
ids = ids[:seq_len]
lbl = lbl[:seq_len]
model_tokens = ids[:-1].tolist()
shifted_labels = lbl[1:]
target_tokens = shifted_labels.clone()
weights = (shifted_labels != -100).float()
target_tokens[target_tokens == -100] = 0
datums.append(
_make_datum(
model_tokens,
{
"target_tokens": target_tokens,
"weights": weights,
},
)
)
return datums
def batch_to_datums_rl(
input_ids: torch.Tensor,
labels: torch.Tensor,
logprobs: torch.Tensor,
advantages: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
"""Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift."""
batch_size = input_ids.size(0)
datums = []
for i in range(batch_size):
ids = input_ids[i]
lbl = labels[i]
if attention_mask is not None:
seq_len = int(attention_mask[i].sum().item())
else:
seq_len = ids.size(0)
ids = ids[:seq_len]
lbl = lbl[:seq_len]
lp = logprobs[i, :seq_len]
adv = advantages[i, :seq_len]
model_tokens = ids[:-1].tolist()
target_tokens = lbl[1:].clone()
target_tokens[target_tokens == -100] = 0
datums.append(
_make_datum(
model_tokens,
{
"target_tokens": target_tokens,
"logprobs": lp[1:],
"advantages": adv[1:],
},
)
)
return datums

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
Creates a dataset with chat-formatted prompts that include
a hidden gold answer tag for the reward function.
Run:
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
"""
import os
import re
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
def extract_boxed(text: str) -> str:
match = re.search(r"\\boxed\{", text)
if not match:
return ""
start = match.end()
depth = 1
i = start
while i < len(text) and depth > 0:
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
i += 1
return text[start : i - 1] if depth == 0 else ""
def main():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
level = os.environ.get("MATH_LEVEL", "Level 1")
filtered_rows = [x for x in ds if x["level"] == level]
print(f"{level} algebra: {len(filtered_rows)} problems")
rows = []
for prob in filtered_rows:
gold = extract_boxed(prob["solution"])
if not gold:
continue
# Format as chat prompt with hidden gold tag
prompt = (
f"Solve the following math problem. "
f"Show your work and put your final answer in \\boxed{{}}.\n\n"
f"{prob['problem']}"
f"<|gold|>{gold}<|/gold|>"
)
# Tokenize the prompt
text = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
)
prompt_ids = tokenizer.encode(text, add_special_tokens=False)
rows.append(
{
"input_ids": prompt_ids,
"labels": [-100] * len(prompt_ids),
"attention_mask": [1] * len(prompt_ids),
}
)
out = Dataset.from_list(rows)
out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
out.save_to_disk(out_dir)
print(f"Saved {len(out)} examples to {out_dir}")
if rows:
print(
f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
f"-{max(len(r['input_ids']) for r in rows)}"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,47 @@
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
#
# Prep:
# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
#
# Run:
# export TINKER_API_KEY="your-key"
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker
lora_rank: 16
loss_fn: importance_sampling
max_sample_tokens: 2048
sample_temperature: 0.7
num_samples: 4
pipeline: true
save_steps: 5
reward_funcs:
- axolotl.integrations.hatchery.rewards.math_reward.math_reward
datasets:
- path: ./data/math_rl_level1
ds_type: arrow
type: completion
sequence_len: 2048
learning_rate: 5.0e-5
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0
max_steps: 10
num_epochs: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
logging_steps: 1
output_dir: ./outputs/tinker-rl-math

View File

@@ -0,0 +1,42 @@
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
#
# Usage:
# export TINKER_API_KEY="your-key"
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker
lora_rank: 16
loss_fn: cross_entropy
pipeline: true
save_steps: 10
datasets:
- path: TeichAI/kimi-k2-thinking-1000x
split: train[:50]
type: chat_template
chat_template: qwen3
split_thinking: true
chat_template: qwen3
sequence_len: 2048
learning_rate: 3.0e-4
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0
num_epochs: 1
max_steps: 20
micro_batch_size: 2
gradient_accumulation_steps: 1
logging_steps: 1
output_dir: ./outputs/tinker-sft

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
from __future__ import annotations
import torch
from peft import PeftModel
from transformers import AutoConfig, PreTrainedModel, Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HatcheryPlugin(BasePlugin):
"""Plugin that replaces local training with remote API calls.
Activated by adding to the axolotl YAML:
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker # or "hatchery"
lora_rank: 32
loss_fn: cross_entropy
# ... see HatcheryConfig for full options
"""
def get_input_args(self) -> str:
return "axolotl.integrations.hatchery.args.HatcheryArgs"
def register(self, cfg: dict):
"""Auto-set config values needed for remote training."""
if cfg.get("remove_unused_columns") is None:
cfg["remove_unused_columns"] = False
def pre_model_load(self, cfg: DictDefault):
"""Replace model loading with a tiny stub."""
hcfg = cfg.hatchery or {}
backend = (
hcfg.get("backend", "tinker")
if isinstance(hcfg, dict)
else getattr(hcfg, "backend", "tinker")
)
LOG.info(
f"Hatchery plugin active: training dispatched to remote "
f"{backend} API. Skipping local model weight loading."
)
from axolotl.loaders import ModelLoader
def _stub_build_model(loader_self) -> bool:
base_model = loader_self.cfg.base_model
LOG.info(f"Skipping model weight loading for: {base_model}")
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
)
class _Stub(PreTrainedModel):
config_class = type(config)
_no_split_modules: list[str] = []
supports_gradient_checkpointing = False
def __init__(self, cfg):
super().__init__(cfg)
vocab_size = getattr(cfg, "vocab_size", 32000)
self.embed_tokens = torch.nn.Embedding(vocab_size, 1)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
pass
def get_output_embeddings(self):
return None
loader_self.model = _Stub(config)
return True
ModelLoader._build_model = _stub_build_model # type: ignore[method-assign,assignment]
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Return the appropriate remote trainer class."""
hcfg = cfg.hatchery
loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy"
if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
from .rl_trainer import HatcheryRLTrainer
return HatcheryRLTrainer
from .trainer import HatcheryTrainer
return HatcheryTrainer
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
model._hatchery_remote = True
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
LOG.info(
"Hatchery: skipping local model save (weights are on remote API). "
"Use `tinker checkpoint download` or hatchery CLI to retrieve."
)
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Inject hatchery config + axolotl training params into the trainer."""
from .args import HatcheryConfig
from .rl_trainer import HatcheryRLTrainer
from .trainer import HatcheryTrainer
if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
return
hcfg = cfg.hatchery
if isinstance(hcfg, dict):
hatchery_config = HatcheryConfig(**hcfg)
elif hcfg is None:
hatchery_config = HatcheryConfig()
else:
hatchery_config = hcfg
trainer.hatchery_args = hatchery_config
trainer._base_model_name = cfg.base_model
# Pull standard training params from axolotl config so they
# don't need to be duplicated under hatchery:
trainer._optim_params = {
"learning_rate": cfg.learning_rate
if cfg.learning_rate is not None
else 1e-4,
"beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9,
"beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95,
"eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12,
"weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0,
"grad_clip_norm": cfg.max_grad_norm
if cfg.max_grad_norm is not None
else 0.0,
}

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Math reward function for hendrycks_math GRPO training.
Uses math_verify for robust answer comparison. Falls back to
exact string match of \\boxed{} content only when math_verify
is unavailable.
"""
from __future__ import annotations
import logging
import re
LOG = logging.getLogger(__name__)
def extract_boxed(text: str) -> str | None:
"""Extract \\boxed{...} answer handling nested braces."""
match = re.search(r"\\boxed\{", text)
if not match:
return None
start = match.end()
depth = 1
i = start
while i < len(text) and depth > 0:
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
i += 1
return text[start : i - 1] if depth == 0 else None
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
"""Score completions by checking if \\boxed{} answer matches the gold answer.
The gold answer is extracted from the prompt (appended as a hidden
tag by the dataset preprocessing). Format:
... <|gold|>ANSWER<|/gold|>
"""
rewards = []
for prompt, completion in zip(prompts, completions, strict=True):
gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
if not gold_match:
rewards.append(0.0)
continue
gold_answer = gold_match.group(1).strip()
pred_answer = extract_boxed(completion)
if pred_answer is None:
rewards.append(0.0)
continue
verified = None
try:
from math_verify import parse, verify
gold_parsed = parse(gold_answer)
pred_parsed = parse(pred_answer)
verified = verify(gold_parsed, pred_parsed)
except Exception:
LOG.debug(
"math_verify unavailable or failed, using string fallback",
exc_info=True,
)
if verified is not None:
rewards.append(1.0 if verified else 0.0)
elif pred_answer.strip() == gold_answer.strip():
rewards.append(1.0)
else:
rewards.append(0.0)
return rewards

View File

@@ -0,0 +1,409 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
Full RL loop per step:
1. Extract prompts from dataset batch
2. Sample N completions per prompt via remote SamplingClient
3. Score completions with local reward functions
4. Compute GRPO-style advantages (per-group normalization)
5. Send (prompt+completion, logprobs, advantages) as forward_backward
6. Optimizer step
"""
from __future__ import annotations
import importlib
import inspect
import re
import time
from typing import Any, Callable, Optional
import torch
from transformers.trainer_utils import TrainOutput
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.logging import get_logger
from .args import HatcheryConfig
from .data import batch_to_datums_rl, datums_to_tinker
from .trainer import _create_training_client
LOG = get_logger(__name__)
def _load_reward_func(fqn: str) -> Callable:
"""Load a reward function from a fully qualified name like 'module.func'."""
module_path = ".".join(fqn.split(".")[:-1])
func_name = fqn.split(".")[-1]
mod = importlib.import_module(module_path)
func = getattr(mod, func_name)
if len(inspect.signature(func).parameters) < 2:
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
return func
class HatcheryRLTrainer(AxolotlTrainer):
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
hatchery_args: Optional[HatcheryConfig]
_base_model_name: Optional[str]
_training_client: Any
_reward_functions: list[Callable]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hatchery_args = None
self._base_model_name = None
self._training_client = None
self._reward_functions = []
def _ensure_reward_functions(self):
if self._reward_functions:
return
args = self.hatchery_args
if not args or not args.reward_funcs:
raise ValueError(
"No reward functions configured. Set hatchery.reward_funcs "
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
)
for fqn in args.reward_funcs:
self._reward_functions.append(_load_reward_func(fqn))
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
def _get_training_client(self):
if self._training_client is not None:
return self._training_client
self._training_client = _create_training_client(
self.hatchery_args, self._base_model_name
)
LOG.info(
f"Remote RL session created: backend={self.hatchery_args.backend}, "
f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
)
return self._training_client
def _sample_completions(self, prompt_ids_list: list[list[int]]):
"""Sample completions for prompts via remote API."""
import tinker.types as tt
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
results = []
sc = tc.save_weights_and_get_sampling_client()
for prompt_ids in prompt_ids_list:
if hasattr(sc, "sampling_session_id"):
sample_result = sc.sample(
prompt_ids,
max_tokens=args.max_sample_tokens,
temperature=args.sample_temperature,
n=args.num_samples,
).result(timeout=args.future_timeout)
else:
mi = tt.ModelInput.from_ints(prompt_ids)
sp = tt.SamplingParams(
max_tokens=args.max_sample_tokens,
temperature=args.sample_temperature,
top_p=0.95,
top_k=-1,
)
sample_result = sc.sample(
prompt=mi,
num_samples=args.num_samples,
sampling_params=sp,
).result(timeout=args.future_timeout)
sequences = (
sample_result.sequences
if hasattr(sample_result, "sequences")
else sample_result.get("sequences", [])
)
for seq in sequences:
tokens = (
list(seq.tokens)
if hasattr(seq, "tokens")
else seq.get("tokens", [])
)
logprobs = (
list(seq.logprobs)
if hasattr(seq, "logprobs") and seq.logprobs
else seq.get("logprobs", [])
)
results.append(
{
"tokens": list(prompt_ids) + tokens,
"completion_tokens": tokens,
"logprobs": logprobs,
"prompt_len": len(prompt_ids),
}
)
return results
def _compute_rewards(
self, prompts: list[str], completions: list[str]
) -> list[float]:
total_rewards = [0.0] * len(completions)
for reward_fn in self._reward_functions:
rewards = reward_fn(prompts, completions)
for i, r in enumerate(rewards):
total_rewards[i] += r
return total_rewards
@staticmethod
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
advantages = []
for i in range(0, len(rewards), group_size):
group = rewards[i : i + group_size]
mean = sum(group) / len(group)
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
std = var**0.5 if var > 1e-8 else 1.0
advantages.extend([(r - mean) / std for r in group])
return advantages
def _do_optim_step(self):
import tinker.types as tt
tc = self._get_training_client()
return tc.optim_step(tt.AdamParams(**self._optim_params))
def train(
self,
resume_from_checkpoint: Optional[str] = None,
trial: Any = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
) -> TrainOutput:
args = self.hatchery_args
if args is None:
raise RuntimeError("hatchery_args not configured")
self._ensure_reward_functions()
train_dataloader = self.get_train_dataloader()
num_train_epochs = int(self.args.num_train_epochs)
max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000
LOG.info(
f"Remote RL training: max_steps={max_steps}, "
f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = True
self.state.is_world_process_zero = True
self.control = self.callback_handler.on_train_begin(
self.args,
self.state,
self.control, # type: ignore[has-type]
)
tokenizer = self.processing_class
global_step = 0
total_loss = 0.0
total_reward = 0.0
start_time = time.time()
for _epoch in range(num_train_epochs):
if global_step >= max_steps:
break
for batch in train_dataloader:
if global_step >= max_steps:
break
self.control = self.callback_handler.on_step_begin(
self.args, self.state, self.control
)
prompt_ids_batch = batch["input_ids"]
# Full prompt text (with gold tag) for reward scoring
prompt_texts = tokenizer.batch_decode(
prompt_ids_batch, skip_special_tokens=False
)
# Strip <|gold|>...<|/gold|> from token ids before
# sending to the model for sampling — the gold answer
# must only be visible to the local reward function.
sampling_prompts = []
for prompt_text in prompt_texts:
clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
clean_ids = tokenizer.encode(clean, add_special_tokens=False)
sampling_prompts.append(clean_ids)
# 1. Sample completions (without gold answer)
t0 = time.time()
samples = self._sample_completions(sampling_prompts)
t_sample = time.time() - t0
if not samples:
LOG.warning("No samples generated, skipping step")
continue
LOG.info(
f"Sampled {len(samples)} completions, "
f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
)
# 2. Decode and score
completion_texts = [
tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
for s in samples
]
sample_prompts = []
for prompt_text in prompt_texts:
sample_prompts.extend([prompt_text] * args.num_samples)
rewards = self._compute_rewards(sample_prompts, completion_texts)
# 3. GRPO advantages
advantages_list = self._compute_advantages(
rewards, group_size=args.num_samples
)
# 4. Build training data
all_datums = []
for i, sample in enumerate(samples):
full_tokens = sample["tokens"]
prompt_len = sample["prompt_len"]
seq_len = len(full_tokens)
input_ids = torch.tensor([full_tokens], dtype=torch.long)
labels = torch.full((1, seq_len), -100, dtype=torch.long)
labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])
logprobs_t = torch.zeros(1, seq_len)
if sample["logprobs"]:
lp = sample["logprobs"][: seq_len - prompt_len]
logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
lp
)
adv_t = torch.zeros(1, seq_len)
adv_t[0, prompt_len:] = advantages_list[i]
all_datums.extend(
batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
)
# 5. Forward backward (one datum at a time for memory) + optim
t0 = time.time()
tc = self._get_training_client()
step_loss = 0.0
for datum in all_datums:
fb_future = tc.forward_backward(
datums_to_tinker([datum]),
loss_fn=args.loss_fn,
loss_fn_config=args.loss_fn_config,
)
fb_result = fb_future.result(timeout=args.future_timeout)
if hasattr(fb_result, "metrics"):
step_loss += float(
(fb_result.metrics or {}).get("loss:sum", 0.0)
)
elif isinstance(fb_result, dict):
step_loss += float(
fb_result.get("metrics", {}).get("loss:sum", 0.0)
)
optim_future = self._do_optim_step()
if not args.pipeline:
optim_future.result(timeout=args.future_timeout)
t_train = time.time() - t0
mean_reward = sum(rewards) / len(rewards)
accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
global_step += 1
total_loss += step_loss
total_reward += mean_reward
self.state.global_step = global_step
log_interval = self.args.logging_steps or 1
if global_step % log_interval == 0:
elapsed = time.time() - start_time
LOG.info(
f"[step {global_step}/{max_steps}] "
f"acc={accuracy:.2f} reward={mean_reward:.3f} "
f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
f"sample={t_sample:.1f}s train={t_train:.1f}s "
f"{elapsed / global_step:.1f}s/step"
)
self.log(
{
"loss": step_loss,
"reward": mean_reward,
"accuracy": accuracy,
"mean_abs_advantage": mean_adv,
"learning_rate": self._optim_params["learning_rate"],
}
)
if args.save_steps and global_step % args.save_steps == 0:
self._save_remote_checkpoint(global_step)
self.control = self.callback_handler.on_step_end(
self.args, self.state, self.control
)
if self.control.should_training_stop:
break
if self.control.should_training_stop:
break
if global_step > 0:
self._save_remote_checkpoint(global_step, name="final")
elapsed = time.time() - start_time
avg_loss = total_loss / max(global_step, 1)
avg_reward = total_reward / max(global_step, 1)
LOG.info(
f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
f"avg_reward={avg_reward:.4f}"
)
self.control = self.callback_handler.on_train_end(
self.args, self.state, self.control
)
return TrainOutput(
global_step=global_step,
training_loss=avg_loss,
metrics={
"train_loss": avg_loss,
"train_reward": avg_reward,
"train_runtime": elapsed,
},
)
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
try:
future = tc.save_state(ckpt_name)
future.result(timeout=args.future_timeout)
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
except Exception:
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
if name == "final":
raise
def save_model(self, output_dir=None, _internal_call=False):
self._save_remote_checkpoint(
step=self.state.global_step,
name=output_dir or "hf-save",
)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
raise NotImplementedError(
"HatcheryRLTrainer uses remote API; compute_loss not called locally."
)

View File

@@ -0,0 +1,327 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Remote trainer that dispatches to Tinker or Hatchery API."""
from __future__ import annotations
import os
import time
from typing import Any, Optional
import torch
from transformers.trainer_utils import TrainOutput
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.logging import get_logger
from .args import HatcheryConfig
from .data import batch_to_datums_sft, datums_to_tinker
LOG = get_logger(__name__)
def _extract_loss(result) -> float:
"""Extract loss:sum from a forward_backward result.
Tinker's cross_entropy (and other losses) return the SUM of per-token
losses, not the mean. This is by design — it lets users control
normalization via the weights tensor. The trainer logs this raw sum;
users who want per-token loss should divide by number of active tokens.
"""
if hasattr(result, "metrics"):
metrics = result.metrics or {}
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
if isinstance(result, dict):
metrics = result.get("metrics", {})
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
return 0.0
def _create_training_client(args: HatcheryConfig, base_model: str):
"""Create a training client for either Tinker or Hatchery backend."""
if args.backend == "tinker":
import tinker
api_key = args.api_key or os.environ.get("TINKER_API_KEY")
if not api_key:
raise ValueError(
"Tinker API key required. Set `hatchery.api_key` in config "
"or TINKER_API_KEY env var."
)
os.environ["TINKER_API_KEY"] = api_key
service = tinker.ServiceClient(project_id=args.project_id)
return service.create_lora_training_client(
base_model=base_model,
rank=args.lora_rank,
train_mlp=args.train_mlp,
train_attn=args.train_attn,
train_unembed=args.train_unembed,
)
from hatchery.core.client import HatcheryClient
base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")
client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout)
return client.create_lora_training_client(
base_model=base_model,
rank=args.lora_rank,
train_attn=args.train_attn,
train_mlp=args.train_mlp,
train_unembed=args.train_unembed,
)
class HatcheryTrainer(AxolotlTrainer):
"""Trainer that sends preprocessed batches to a remote training API.
Replaces local forward/backward with remote API calls to Tinker or
Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
chat templates, packing, etc.) but offloads compute to remote GPUs.
"""
hatchery_args: Optional[HatcheryConfig]
_base_model_name: Optional[str]
_training_client: Any
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hatchery_args = None
self._base_model_name = None
self._training_client = None
def _get_training_client(self):
"""Lazily create the remote training session."""
if self._training_client is not None:
return self._training_client
args = self.hatchery_args
if args is None:
raise RuntimeError(
"HatcheryTrainer.hatchery_args not set. "
"Ensure the HatcheryPlugin is registered."
)
base_model = self._base_model_name
if not base_model:
raise RuntimeError("HatcheryTrainer._base_model_name not set.")
self._training_client = _create_training_client(args, base_model)
LOG.info(
f"Remote training session created: backend={args.backend}, "
f"model={base_model}, rank={args.lora_rank}"
)
return self._training_client
def _send_batch(self, batch: dict[str, torch.Tensor]):
"""Convert batch to datums and send forward_backward to remote.
Returns (future, n_active_tokens) where n_active_tokens counts
the completion tokens in this batch (for loss normalization).
"""
input_ids = batch["input_ids"]
labels = batch["labels"]
attention_mask = batch.get("attention_mask")
n_active = int((labels[:, 1:] != -100).sum().item())
datums = batch_to_datums_sft(input_ids, labels, attention_mask)
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
send_datums = datums_to_tinker(datums)
future = tc.forward_backward(
send_datums,
loss_fn=args.loss_fn,
loss_fn_config=args.loss_fn_config,
)
return future, n_active
def _do_optim_step(self):
"""Send optimizer step to remote using axolotl's training params."""
import tinker.types as tt
tc = self._get_training_client()
return tc.optim_step(tt.AdamParams(**self._optim_params))
def train(
self,
resume_from_checkpoint: Optional[str] = None,
trial: Any = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
) -> TrainOutput:
"""Main training loop — sends batches to remote API."""
args = self.hatchery_args
if args is None:
raise RuntimeError("hatchery_args not configured")
train_dataloader = self.get_train_dataloader()
num_batches = len(train_dataloader)
grad_accum = self.args.gradient_accumulation_steps
num_train_epochs = int(self.args.num_train_epochs)
steps_per_epoch = max(num_batches // grad_accum, 1)
max_steps = (
self.args.max_steps
if self.args.max_steps > 0
else steps_per_epoch * num_train_epochs
)
LOG.info(
f"Remote training: {num_batches} batches/epoch, "
f"{grad_accum} grad_accum, {max_steps} max steps, "
f"{num_train_epochs} epochs"
)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = True
self.state.is_world_process_zero = True
self.control = self.callback_handler.on_train_begin(
self.args,
self.state,
self.control, # type: ignore[has-type]
)
global_step = 0
total_loss = 0.0
start_time = time.time()
for _epoch in range(num_train_epochs):
if global_step >= max_steps:
break
self.control = self.callback_handler.on_epoch_begin(
self.args, self.state, self.control
)
pending_fb_futures = []
accum_count = 0
for batch_idx, batch in enumerate(train_dataloader):
if global_step >= max_steps:
break
self.control = self.callback_handler.on_step_begin(
self.args, self.state, self.control
)
fb_future, n_active = self._send_batch(batch)
pending_fb_futures.append((fb_future, n_active))
accum_count += 1
if accum_count >= grad_accum:
step_loss_sum = 0.0
step_active = 0
for fut, n_act in pending_fb_futures:
result = fut.result(timeout=args.future_timeout)
step_loss_sum += _extract_loss(result)
step_active += n_act
optim_future = self._do_optim_step()
if not args.pipeline:
optim_future.result(timeout=args.future_timeout)
step_loss = (
step_loss_sum / step_active
if step_active > 0
else step_loss_sum
)
global_step += 1
total_loss += step_loss
self.state.global_step = global_step
self.state.epoch = _epoch + (batch_idx + 1) / num_batches
log_interval = self.args.logging_steps or 1
if global_step % log_interval == 0:
elapsed = time.time() - start_time
avg_loss = total_loss / global_step
LOG.info(
f"[step {global_step}/{max_steps}] "
f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
f"active={step_active} "
f"{elapsed / global_step:.2f}s/step"
)
self.log(
{
"loss": step_loss,
"learning_rate": self._optim_params["learning_rate"],
"epoch": self.state.epoch,
}
)
if args.save_steps and global_step % args.save_steps == 0:
self._save_remote_checkpoint(global_step)
self.control = self.callback_handler.on_step_end(
self.args, self.state, self.control
)
pending_fb_futures = []
accum_count = 0
if self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(
self.args, self.state, self.control
)
if self.control.should_training_stop:
break
if global_step > 0:
self._save_remote_checkpoint(global_step, name="final")
elapsed = time.time() - start_time
avg_loss = total_loss / max(global_step, 1)
LOG.info(
f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
)
self.control = self.callback_handler.on_train_end(
self.args, self.state, self.control
)
return TrainOutput(
global_step=global_step,
training_loss=avg_loss,
metrics={"train_loss": avg_loss, "train_runtime": elapsed},
)
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
"""Save a checkpoint on the remote service."""
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
try:
future = tc.save_state(ckpt_name)
future.result(timeout=args.future_timeout)
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
except Exception:
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
if name == "final":
raise
def save_model(self, output_dir=None, _internal_call=False):
"""Delegate to remote checkpoint save so HF callbacks create checkpoints."""
self._save_remote_checkpoint(
step=self.state.global_step,
name=output_dir or "hf-save",
)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
raise NotImplementedError(
"HatcheryTrainer uses remote API; compute_loss should not be called."
)