train on remote compute using Tinker compatible APIs (#3614)
* train on remote compute using Tinker compatible APIs * chore: lint * fixes with latest hatchery changes * chore: lint
This commit is contained in:
27
src/axolotl/integrations/hatchery/__init__.py
Normal file
27
src/axolotl/integrations/hatchery/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""Hatchery/Tinker remote training integration for Axolotl.

Routes axolotl's preprocessed data to a remote training API (Tinker or
Hatchery) instead of running forward/backward locally. The remote
service handles model weights, LoRA adapters, and gradient updates.
"""

from .args import HatcheryArgs, HatcheryConfig
from .plugin import HatcheryPlugin

# Public API of the integration package.
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]

# Usage (axolotl YAML):
# plugins:
#   - axolotl.integrations.hatchery.HatcheryPlugin
#
# hatchery:
#   backend: tinker          # or "hatchery"
#   lora_rank: 32
#   loss_fn: cross_entropy   # SFT
#   # loss_fn: ppo           # RL (auto-selects HatcheryRLTrainer)
#
# learning_rate: 1e-4        # top-level, not under hatchery:
|
||||
62
src/axolotl/integrations/hatchery/args.py
Normal file
62
src/axolotl/integrations/hatchery/args.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Pydantic config schema for the Hatchery integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HatcheryConfig(BaseModel):
    """Nested config under `hatchery:` in the axolotl YAML.

    Only contains hatchery-specific settings. Standard training params
    (learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
    gradient_accumulation_steps) are read from axolotl's top-level config.
    """

    # Backend & connection
    backend: Literal["tinker", "hatchery"] = "tinker"
    # base_url/api_key/project_id default to None — presumably the client
    # falls back to backend defaults / env vars (e.g. TINKER_API_KEY per
    # the example YAMLs); confirm against the client constructor.
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    project_id: Optional[str] = None

    # LoRA config sent to remote
    lora_rank: int = Field(32, ge=1, le=256)
    train_attn: bool = True
    train_mlp: bool = True
    train_unembed: bool = True

    # Loss function applied remotely; non-cross_entropy values route the
    # plugin to HatcheryRLTrainer (see HatcheryPlugin.get_trainer_cls).
    loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
        "cross_entropy"
    )
    loss_fn_config: Optional[dict[str, Any]] = None

    # Pipelining: submit next batch before awaiting previous result
    pipeline: bool = True

    # Sampling params (for RL flows)
    max_sample_tokens: int = 256
    sample_temperature: float = 1.0
    num_samples: int = 4

    # Reward functions (for RL) — list of fully qualified names
    reward_funcs: Optional[list[str]] = None

    # Checkpointing: save a remote checkpoint every `save_steps` steps
    # (None disables periodic saves; a final save still happens).
    save_steps: Optional[int] = None
    save_name_prefix: str = "checkpoint"

    # Timeout per future (seconds)
    future_timeout: float = 600.0
|
||||
|
||||
|
||||
class HatcheryArgs(BaseModel):
    """Top-level mixin that adds the nested `hatchery:` field.

    Returned (by dotted path) from HatcheryPlugin.get_input_args so
    axolotl merges this schema into its top-level config validation.
    """

    # Entire integration config; None when the plugin is loaded but the
    # `hatchery:` section is omitted (defaults are applied later).
    hatchery: Optional[HatcheryConfig] = None
|
||||
160
src/axolotl/integrations/hatchery/data.py
Normal file
160
src/axolotl/integrations/hatchery/data.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
|
||||
|
||||
Both Tinker and Hatchery expect the client to apply the causal LM shift:
|
||||
|
||||
Original tokens: [t0, t1, t2, ..., t_{L-1}]
|
||||
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
|
||||
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
|
||||
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
|
||||
|
||||
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
|
||||
"""Serialize a tensor to the TensorData wire dict."""
|
||||
flat = t.detach().cpu().flatten()
|
||||
dtype_map = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.int64: "int64",
|
||||
torch.int32: "int32",
|
||||
}
|
||||
return {
|
||||
"dtype": dtype_map.get(flat.dtype, "float32"),
|
||||
"shape": list(t.shape),
|
||||
"data": flat.tolist(),
|
||||
}
|
||||
|
||||
|
||||
def _make_datum(
|
||||
tokens: list[int],
|
||||
loss_fn_inputs: dict[str, torch.Tensor],
|
||||
) -> dict[str, Any]:
|
||||
"""Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
|
||||
return {
|
||||
"model_input": {
|
||||
"chunks": [{"type": "encoded_text", "tokens": tokens}],
|
||||
},
|
||||
"loss_fn_inputs": {
|
||||
key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def datums_to_tinker(datums: list[dict[str, Any]]):
    """Wrap plain-dict datums into tinker.types.Datum objects.

    Both the Tinker SDK and updated Hatchery client accept these.
    """
    import tinker.types as tt

    converted = []
    for datum in datums:
        loss_inputs = {
            name: tt.TensorData(
                data=payload["data"],
                dtype=payload["dtype"],
                shape=payload["shape"],
            )
            for name, payload in datum["loss_fn_inputs"].items()
        }
        token_ids = datum["model_input"]["chunks"][0]["tokens"]
        converted.append(
            tt.Datum(
                model_input=tt.ModelInput.from_ints(token_ids),
                loss_fn_inputs=loss_inputs,
            )
        )
    return converted
|
||||
|
||||
|
||||
def batch_to_datums_sft(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
    """Convert an axolotl SFT batch to Datum dicts, applying the causal shift.

    Per row: the model input drops the final token, targets/weights drop
    the first, masked (-100) targets get weight 0.0 and token id 0. Rows
    are truncated to the attention-mask length when a mask is provided.
    """
    out: list[dict[str, Any]] = []
    for row in range(input_ids.size(0)):
        ids = input_ids[row]
        lbl = labels[row]

        if attention_mask is not None:
            valid = int(attention_mask[row].sum().item())
            ids = ids[:valid]
            lbl = lbl[:valid]

        shifted = lbl[1:]
        # Weight 1.0 only where the shifted label is a real token.
        weights = (shifted != -100).float()
        targets = shifted.clone()
        targets[targets == -100] = 0

        out.append(
            _make_datum(
                ids[:-1].tolist(),
                {
                    "target_tokens": targets,
                    "weights": weights,
                },
            )
        )

    return out
|
||||
|
||||
|
||||
def batch_to_datums_rl(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    logprobs: torch.Tensor,
    advantages: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
    """Convert an RL batch to importance_sampling/ppo Datum dicts with the causal shift.

    Per row: model input drops the final token; targets, logprobs, and
    advantages drop the first so position i predicts token i+1. Masked
    (-100) targets are replaced with token id 0.
    """
    out: list[dict[str, Any]] = []
    for row in range(input_ids.size(0)):
        ids = input_ids[row]
        lbl = labels[row]

        # Effective length: mask sum when a mask is given, else the full row.
        if attention_mask is None:
            valid = ids.size(0)
        else:
            valid = int(attention_mask[row].sum().item())
        ids = ids[:valid]
        lbl = lbl[:valid]
        lp = logprobs[row, :valid]
        adv = advantages[row, :valid]

        targets = lbl[1:].clone()
        targets[targets == -100] = 0

        out.append(
            _make_datum(
                ids[:-1].tolist(),
                {
                    "target_tokens": targets,
                    "logprobs": lp[1:],
                    "advantages": adv[1:],
                },
            )
        )

    return out
|
||||
87
src/axolotl/integrations/hatchery/examples/prep_math_rl.py
Normal file
87
src/axolotl/integrations/hatchery/examples/prep_math_rl.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
|
||||
|
||||
Creates a dataset with chat-formatted prompts that include
|
||||
a hidden gold answer tag for the reward function.
|
||||
|
||||
Run:
|
||||
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def extract_boxed(text: str) -> str:
    """Return the content of the first \\boxed{...} in *text*.

    Handles nested braces with a simple depth counter. Returns "" when
    no \\boxed{} exists or its braces never close.
    """
    opener = re.search(r"\\boxed\{", text)
    if opener is None:
        return ""
    begin = opener.end()
    open_braces = 1
    cursor = begin
    length = len(text)
    while cursor < length and open_braces > 0:
        ch = text[cursor]
        if ch == "{":
            open_braces += 1
        elif ch == "}":
            open_braces -= 1
        cursor += 1
    if open_braces != 0:
        return ""
    return text[begin : cursor - 1]
|
||||
|
||||
|
||||
def main():
    """Build and save a tokenized RL prompt dataset from hendrycks_math.

    Each row is a chat-formatted prompt ending in a hidden
    <|gold|>ANSWER<|/gold|> tag: the trainer strips the tag before
    sampling, and the reward function reads it to score completions.
    Labels are all -100 because completions are generated, not imitated.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)

    ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
    # Difficulty filter, overridable via env (default "Level 1").
    level = os.environ.get("MATH_LEVEL", "Level 1")
    filtered_rows = [x for x in ds if x["level"] == level]
    print(f"{level} algebra: {len(filtered_rows)} problems")

    rows = []
    for prob in filtered_rows:
        gold = extract_boxed(prob["solution"])
        if not gold:
            # No parseable \boxed{} answer -> cannot score; drop the problem.
            continue

        # Format as chat prompt with hidden gold tag
        prompt = (
            f"Solve the following math problem. "
            f"Show your work and put your final answer in \\boxed{{}}.\n\n"
            f"{prob['problem']}"
            f"<|gold|>{gold}<|/gold|>"
        )

        # Tokenize the prompt
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_ids = tokenizer.encode(text, add_special_tokens=False)

        rows.append(
            {
                "input_ids": prompt_ids,
                # All -100: nothing to imitate, prompts only.
                "labels": [-100] * len(prompt_ids),
                "attention_mask": [1] * len(prompt_ids),
            }
        )

    out = Dataset.from_list(rows)
    # e.g. "Level 1" -> ./data/math_rl_level1 (matches the example YAML path)
    out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
    out.save_to_disk(out_dir)
    print(f"Saved {len(out)} examples to {out_dir}")
    if rows:
        print(
            f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
            f"-{max(len(r['input_ids']) for r in rows)}"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
47
src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
Normal file
47
src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
|
||||
#
|
||||
# Prep:
|
||||
# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
||||
#
|
||||
# Run:
|
||||
# export TINKER_API_KEY="your-key"
|
||||
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
|
||||
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||
|
||||
hatchery:
|
||||
backend: tinker
|
||||
lora_rank: 16
|
||||
loss_fn: importance_sampling
|
||||
max_sample_tokens: 2048
|
||||
sample_temperature: 0.7
|
||||
num_samples: 4
|
||||
pipeline: true
|
||||
save_steps: 5
|
||||
reward_funcs:
|
||||
- axolotl.integrations.hatchery.rewards.math_reward.math_reward
|
||||
|
||||
datasets:
|
||||
- path: ./data/math_rl_level1
|
||||
ds_type: arrow
|
||||
type: completion
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
learning_rate: 5.0e-5
|
||||
optimizer: adamw_torch
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.95
|
||||
weight_decay: 0.01
|
||||
max_grad_norm: 1.0
|
||||
|
||||
max_steps: 10
|
||||
num_epochs: 1
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
logging_steps: 1
|
||||
|
||||
output_dir: ./outputs/tinker-rl-math
|
||||
42
src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
Normal file
42
src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
|
||||
#
|
||||
# Usage:
|
||||
# export TINKER_API_KEY="your-key"
|
||||
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
|
||||
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||
|
||||
hatchery:
|
||||
backend: tinker
|
||||
lora_rank: 16
|
||||
loss_fn: cross_entropy
|
||||
pipeline: true
|
||||
save_steps: 10
|
||||
|
||||
datasets:
|
||||
- path: TeichAI/kimi-k2-thinking-1000x
|
||||
split: train[:50]
|
||||
type: chat_template
|
||||
chat_template: qwen3
|
||||
split_thinking: true
|
||||
|
||||
chat_template: qwen3
|
||||
sequence_len: 2048
|
||||
|
||||
learning_rate: 3.0e-4
|
||||
optimizer: adamw_torch
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.95
|
||||
weight_decay: 0.01
|
||||
max_grad_norm: 1.0
|
||||
|
||||
num_epochs: 1
|
||||
max_steps: 20
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 1
|
||||
logging_steps: 1
|
||||
|
||||
output_dir: ./outputs/tinker-sft
|
||||
147
src/axolotl/integrations/hatchery/plugin.py
Normal file
147
src/axolotl/integrations/hatchery/plugin.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import AutoConfig, PreTrainedModel, Trainer
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class HatcheryPlugin(BasePlugin):
    """Plugin that replaces local training with remote API calls.

    Activated by adding to the axolotl YAML:

        plugins:
          - axolotl.integrations.hatchery.HatcheryPlugin

        hatchery:
          backend: tinker  # or "hatchery"
          lora_rank: 32
          loss_fn: cross_entropy
          # ... see HatcheryConfig for full options
    """

    def get_input_args(self) -> str:
        # Dotted path to the pydantic mixin that extends axolotl's config schema.
        return "axolotl.integrations.hatchery.args.HatcheryArgs"

    def register(self, cfg: dict):
        """Auto-set config values needed for remote training."""
        # Keep all dataset columns unless the user explicitly opted in;
        # the remote flow consumes raw batch tensors, not model kwargs.
        if cfg.get("remove_unused_columns") is None:
            cfg["remove_unused_columns"] = False

    def pre_model_load(self, cfg: DictDefault):
        """Replace model loading with a tiny stub.

        Monkeypatches ModelLoader._build_model so no model weights are
        downloaded/loaded locally — the remote service owns the weights.
        The stub only carries the HF config plus a 1-dim embedding so
        code that inspects vocab/embeddings still functions.
        """
        hcfg = cfg.hatchery or {}
        # hcfg may be a plain dict or a validated HatcheryConfig instance.
        backend = (
            hcfg.get("backend", "tinker")
            if isinstance(hcfg, dict)
            else getattr(hcfg, "backend", "tinker")
        )
        LOG.info(
            f"Hatchery plugin active: training dispatched to remote "
            f"{backend} API. Skipping local model weight loading."
        )

        from axolotl.loaders import ModelLoader

        def _stub_build_model(loader_self) -> bool:
            # Replacement for ModelLoader._build_model: config only, no weights.
            base_model = loader_self.cfg.base_model
            LOG.info(f"Skipping model weight loading for: {base_model}")

            config = AutoConfig.from_pretrained(
                base_model,
                trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
            )

            class _Stub(PreTrainedModel):
                # Minimal PreTrainedModel so downstream code that touches
                # the model object (tokenizer resize checks, etc.) works.
                config_class = type(config)
                _no_split_modules: list[str] = []
                supports_gradient_checkpointing = False

                def __init__(self, cfg):
                    super().__init__(cfg)
                    # Embedding dim of 1 keeps the stub near-zero cost.
                    vocab_size = getattr(cfg, "vocab_size", 32000)
                    self.embed_tokens = torch.nn.Embedding(vocab_size, 1)

                def get_input_embeddings(self):
                    return self.embed_tokens

                def set_input_embeddings(self, value):
                    # Intentionally a no-op: the stub's embeddings are never trained.
                    pass

                def get_output_embeddings(self):
                    return None

            loader_self.model = _Stub(config)
            return True

        ModelLoader._build_model = _stub_build_model  # type: ignore[method-assign,assignment]

    def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
        """Return the appropriate remote trainer class.

        RL-style loss functions select HatcheryRLTrainer; everything
        else (cross_entropy SFT) uses HatcheryTrainer.
        """
        hcfg = cfg.hatchery
        loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy"

        if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
            from .rl_trainer import HatcheryRLTrainer

            return HatcheryRLTrainer

        from .trainer import HatcheryTrainer

        return HatcheryTrainer

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        # Marker attribute so other code can detect remote-training mode.
        model._hatchery_remote = True

    def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        # Nothing to persist locally — weights only ever lived remotely.
        LOG.info(
            "Hatchery: skipping local model save (weights are on remote API). "
            "Use `tinker checkpoint download` or hatchery CLI to retrieve."
        )

    def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
        """Inject hatchery config + axolotl training params into the trainer."""
        from .args import HatcheryConfig
        from .rl_trainer import HatcheryRLTrainer
        from .trainer import HatcheryTrainer

        if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
            return

        # Normalize to a HatcheryConfig instance regardless of how the
        # YAML section arrived (dict, None, or already-validated model).
        hcfg = cfg.hatchery
        if isinstance(hcfg, dict):
            hatchery_config = HatcheryConfig(**hcfg)
        elif hcfg is None:
            hatchery_config = HatcheryConfig()
        else:
            hatchery_config = hcfg

        trainer.hatchery_args = hatchery_config
        trainer._base_model_name = cfg.base_model

        # Pull standard training params from axolotl config so they
        # don't need to be duplicated under hatchery:
        trainer._optim_params = {
            "learning_rate": cfg.learning_rate
            if cfg.learning_rate is not None
            else 1e-4,
            "beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9,
            "beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95,
            "eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12,
            "weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0,
            "grad_clip_norm": cfg.max_grad_norm
            if cfg.max_grad_norm is not None
            else 0.0,
        }
|
||||
3
src/axolotl/integrations/hatchery/rewards/__init__.py
Normal file
3
src/axolotl/integrations/hatchery/rewards/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
78
src/axolotl/integrations/hatchery/rewards/math_reward.py
Normal file
78
src/axolotl/integrations/hatchery/rewards/math_reward.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Math reward function for hendrycks_math GRPO training.
|
||||
|
||||
Uses math_verify for robust answer comparison. Falls back to
|
||||
exact string match of \\boxed{} content only when math_verify
|
||||
is unavailable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_boxed(text: str) -> str | None:
|
||||
"""Extract \\boxed{...} answer handling nested braces."""
|
||||
match = re.search(r"\\boxed\{", text)
|
||||
if not match:
|
||||
return None
|
||||
start = match.end()
|
||||
depth = 1
|
||||
i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{":
|
||||
depth += 1
|
||||
elif text[i] == "}":
|
||||
depth -= 1
|
||||
i += 1
|
||||
return text[start : i - 1] if depth == 0 else None
|
||||
|
||||
|
||||
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
|
||||
"""Score completions by checking if \\boxed{} answer matches the gold answer.
|
||||
|
||||
The gold answer is extracted from the prompt (appended as a hidden
|
||||
tag by the dataset preprocessing). Format:
|
||||
... <|gold|>ANSWER<|/gold|>
|
||||
"""
|
||||
rewards = []
|
||||
for prompt, completion in zip(prompts, completions, strict=True):
|
||||
gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
|
||||
if not gold_match:
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
|
||||
gold_answer = gold_match.group(1).strip()
|
||||
pred_answer = extract_boxed(completion)
|
||||
|
||||
if pred_answer is None:
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
|
||||
verified = None
|
||||
try:
|
||||
from math_verify import parse, verify
|
||||
|
||||
gold_parsed = parse(gold_answer)
|
||||
pred_parsed = parse(pred_answer)
|
||||
verified = verify(gold_parsed, pred_parsed)
|
||||
except Exception:
|
||||
LOG.debug(
|
||||
"math_verify unavailable or failed, using string fallback",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if verified is not None:
|
||||
rewards.append(1.0 if verified else 0.0)
|
||||
elif pred_answer.strip() == gold_answer.strip():
|
||||
rewards.append(1.0)
|
||||
else:
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
|
||||
409
src/axolotl/integrations/hatchery/rl_trainer.py
Normal file
409
src/axolotl/integrations/hatchery/rl_trainer.py
Normal file
@@ -0,0 +1,409 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
|
||||
|
||||
Full RL loop per step:
|
||||
1. Extract prompts from dataset batch
|
||||
2. Sample N completions per prompt via remote SamplingClient
|
||||
3. Score completions with local reward functions
|
||||
4. Compute GRPO-style advantages (per-group normalization)
|
||||
5. Send (prompt+completion, logprobs, advantages) as forward_backward
|
||||
6. Optimizer step
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from transformers.trainer_utils import TrainOutput
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .args import HatcheryConfig
|
||||
from .data import batch_to_datums_rl, datums_to_tinker
|
||||
from .trainer import _create_training_client
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _load_reward_func(fqn: str) -> Callable:
|
||||
"""Load a reward function from a fully qualified name like 'module.func'."""
|
||||
module_path = ".".join(fqn.split(".")[:-1])
|
||||
func_name = fqn.split(".")[-1]
|
||||
mod = importlib.import_module(module_path)
|
||||
func = getattr(mod, func_name)
|
||||
if len(inspect.signature(func).parameters) < 2:
|
||||
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
|
||||
return func
|
||||
|
||||
|
||||
class HatcheryRLTrainer(AxolotlTrainer):
|
||||
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
|
||||
|
||||
hatchery_args: Optional[HatcheryConfig]
|
||||
_base_model_name: Optional[str]
|
||||
_training_client: Any
|
||||
_reward_functions: list[Callable]
|
||||
|
||||
    def __init__(self, *args, **kwargs):
        """Forward to AxolotlTrainer and initialize remote-session state.

        hatchery_args and _base_model_name are injected afterwards by
        HatcheryPlugin.post_trainer_create; the training client and
        reward functions are created lazily on first use.
        """
        super().__init__(*args, **kwargs)
        self.hatchery_args = None
        self._base_model_name = None
        self._training_client = None
        self._reward_functions = []
|
||||
|
||||
def _ensure_reward_functions(self):
|
||||
if self._reward_functions:
|
||||
return
|
||||
args = self.hatchery_args
|
||||
if not args or not args.reward_funcs:
|
||||
raise ValueError(
|
||||
"No reward functions configured. Set hatchery.reward_funcs "
|
||||
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
|
||||
)
|
||||
for fqn in args.reward_funcs:
|
||||
self._reward_functions.append(_load_reward_func(fqn))
|
||||
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
|
||||
|
||||
    def _get_training_client(self):
        """Return the cached remote training client, creating it on first call."""
        if self._training_client is not None:
            return self._training_client

        # _create_training_client (shared with HatcheryTrainer) builds the
        # backend-specific session from hatchery_args + base model name.
        self._training_client = _create_training_client(
            self.hatchery_args, self._base_model_name
        )
        LOG.info(
            f"Remote RL session created: backend={self.hatchery_args.backend}, "
            f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
        )
        return self._training_client
|
||||
|
||||
def _sample_completions(self, prompt_ids_list: list[list[int]]):
|
||||
"""Sample completions for prompts via remote API."""
|
||||
import tinker.types as tt
|
||||
|
||||
tc = self._get_training_client()
|
||||
args = self.hatchery_args
|
||||
assert args is not None # validated by _get_training_client
|
||||
results = []
|
||||
|
||||
sc = tc.save_weights_and_get_sampling_client()
|
||||
|
||||
for prompt_ids in prompt_ids_list:
|
||||
if hasattr(sc, "sampling_session_id"):
|
||||
sample_result = sc.sample(
|
||||
prompt_ids,
|
||||
max_tokens=args.max_sample_tokens,
|
||||
temperature=args.sample_temperature,
|
||||
n=args.num_samples,
|
||||
).result(timeout=args.future_timeout)
|
||||
else:
|
||||
mi = tt.ModelInput.from_ints(prompt_ids)
|
||||
sp = tt.SamplingParams(
|
||||
max_tokens=args.max_sample_tokens,
|
||||
temperature=args.sample_temperature,
|
||||
top_p=0.95,
|
||||
top_k=-1,
|
||||
)
|
||||
sample_result = sc.sample(
|
||||
prompt=mi,
|
||||
num_samples=args.num_samples,
|
||||
sampling_params=sp,
|
||||
).result(timeout=args.future_timeout)
|
||||
|
||||
sequences = (
|
||||
sample_result.sequences
|
||||
if hasattr(sample_result, "sequences")
|
||||
else sample_result.get("sequences", [])
|
||||
)
|
||||
for seq in sequences:
|
||||
tokens = (
|
||||
list(seq.tokens)
|
||||
if hasattr(seq, "tokens")
|
||||
else seq.get("tokens", [])
|
||||
)
|
||||
logprobs = (
|
||||
list(seq.logprobs)
|
||||
if hasattr(seq, "logprobs") and seq.logprobs
|
||||
else seq.get("logprobs", [])
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"tokens": list(prompt_ids) + tokens,
|
||||
"completion_tokens": tokens,
|
||||
"logprobs": logprobs,
|
||||
"prompt_len": len(prompt_ids),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _compute_rewards(
|
||||
self, prompts: list[str], completions: list[str]
|
||||
) -> list[float]:
|
||||
total_rewards = [0.0] * len(completions)
|
||||
for reward_fn in self._reward_functions:
|
||||
rewards = reward_fn(prompts, completions)
|
||||
for i, r in enumerate(rewards):
|
||||
total_rewards[i] += r
|
||||
return total_rewards
|
||||
|
||||
@staticmethod
|
||||
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
|
||||
advantages = []
|
||||
for i in range(0, len(rewards), group_size):
|
||||
group = rewards[i : i + group_size]
|
||||
mean = sum(group) / len(group)
|
||||
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
|
||||
std = var**0.5 if var > 1e-8 else 1.0
|
||||
advantages.extend([(r - mean) / std for r in group])
|
||||
return advantages
|
||||
|
||||
    def _do_optim_step(self):
        """Submit a remote Adam optimizer step; returns the future (not awaited here).

        _optim_params is injected by HatcheryPlugin.post_trainer_create
        from axolotl's top-level training config.
        """
        import tinker.types as tt

        tc = self._get_training_client()
        return tc.optim_step(tt.AdamParams(**self._optim_params))
|
||||
|
||||
    def train(
        self,
        resume_from_checkpoint: Optional[str] = None,
        trial: Any = None,
        ignore_keys_for_eval: Optional[list[str]] = None,
        **kwargs,
    ) -> TrainOutput:
        """Run the remote GRPO/PPO loop; see module docstring for the per-step flow.

        The HF-style signature is kept for compatibility with callers;
        resume_from_checkpoint / trial / ignore_keys_for_eval are
        currently not used by the remote flow.
        """
        args = self.hatchery_args
        if args is None:
            raise RuntimeError("hatchery_args not configured")

        self._ensure_reward_functions()

        train_dataloader = self.get_train_dataloader()
        num_train_epochs = int(self.args.num_train_epochs)
        # Remote RL requires a step budget; default to 1000 when unset.
        max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000

        LOG.info(
            f"Remote RL training: max_steps={max_steps}, "
            f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
        )

        # Single-process remote loop: this process is always rank zero.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = True
        self.state.is_world_process_zero = True

        self.control = self.callback_handler.on_train_begin(
            self.args,
            self.state,
            self.control,  # type: ignore[has-type]
        )

        tokenizer = self.processing_class
        global_step = 0
        total_loss = 0.0
        total_reward = 0.0
        start_time = time.time()

        for _epoch in range(num_train_epochs):
            if global_step >= max_steps:
                break

            for batch in train_dataloader:
                if global_step >= max_steps:
                    break

                self.control = self.callback_handler.on_step_begin(
                    self.args, self.state, self.control
                )

                prompt_ids_batch = batch["input_ids"]
                # Full prompt text (with gold tag) for reward scoring
                prompt_texts = tokenizer.batch_decode(
                    prompt_ids_batch, skip_special_tokens=False
                )

                # Strip <|gold|>...<|/gold|> from token ids before
                # sending to the model for sampling — the gold answer
                # must only be visible to the local reward function.
                sampling_prompts = []
                for prompt_text in prompt_texts:
                    clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
                    clean_ids = tokenizer.encode(clean, add_special_tokens=False)
                    sampling_prompts.append(clean_ids)

                # 1. Sample completions (without gold answer)
                t0 = time.time()
                samples = self._sample_completions(sampling_prompts)
                t_sample = time.time() - t0

                if not samples:
                    LOG.warning("No samples generated, skipping step")
                    continue
                LOG.info(
                    f"Sampled {len(samples)} completions, "
                    f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
                )

                # 2. Decode and score
                completion_texts = [
                    tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
                    for s in samples
                ]
                # Repeat each prompt num_samples times so prompts align
                # 1:1 with the flattened completion list.
                sample_prompts = []
                for prompt_text in prompt_texts:
                    sample_prompts.extend([prompt_text] * args.num_samples)

                rewards = self._compute_rewards(sample_prompts, completion_texts)

                # 3. GRPO advantages
                advantages_list = self._compute_advantages(
                    rewards, group_size=args.num_samples
                )

                # 4. Build training data
                all_datums = []
                for i, sample in enumerate(samples):
                    full_tokens = sample["tokens"]
                    prompt_len = sample["prompt_len"]
                    seq_len = len(full_tokens)

                    # Labels: -100 over the prompt, real tokens over the
                    # completion, so only completion positions are trained.
                    input_ids = torch.tensor([full_tokens], dtype=torch.long)
                    labels = torch.full((1, seq_len), -100, dtype=torch.long)
                    labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])

                    # Sampler logprobs aligned to the completion span;
                    # positions without logprobs stay 0.
                    logprobs_t = torch.zeros(1, seq_len)
                    if sample["logprobs"]:
                        lp = sample["logprobs"][: seq_len - prompt_len]
                        logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
                            lp
                        )

                    # Same scalar advantage broadcast over the completion span.
                    adv_t = torch.zeros(1, seq_len)
                    adv_t[0, prompt_len:] = advantages_list[i]

                    all_datums.extend(
                        batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
                    )

                # 5. Forward backward (one datum at a time for memory) + optim
                t0 = time.time()
                tc = self._get_training_client()
                step_loss = 0.0
                for datum in all_datums:
                    fb_future = tc.forward_backward(
                        datums_to_tinker([datum]),
                        loss_fn=args.loss_fn,
                        loss_fn_config=args.loss_fn_config,
                    )
                    fb_result = fb_future.result(timeout=args.future_timeout)
                    # Result may be an object with .metrics or a plain dict.
                    if hasattr(fb_result, "metrics"):
                        step_loss += float(
                            (fb_result.metrics or {}).get("loss:sum", 0.0)
                        )
                    elif isinstance(fb_result, dict):
                        step_loss += float(
                            fb_result.get("metrics", {}).get("loss:sum", 0.0)
                        )
                optim_future = self._do_optim_step()
                # Pipelining: skip awaiting the optim future so the next
                # sampling round can overlap with the remote update.
                if not args.pipeline:
                    optim_future.result(timeout=args.future_timeout)
                t_train = time.time() - t0

                mean_reward = sum(rewards) / len(rewards)
                # "accuracy" here = fraction of completions with positive reward.
                accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
                mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
                global_step += 1
                total_loss += step_loss
                total_reward += mean_reward
                self.state.global_step = global_step

                log_interval = self.args.logging_steps or 1
                if global_step % log_interval == 0:
                    elapsed = time.time() - start_time
                    LOG.info(
                        f"[step {global_step}/{max_steps}] "
                        f"acc={accuracy:.2f} reward={mean_reward:.3f} "
                        f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
                        f"sample={t_sample:.1f}s train={t_train:.1f}s "
                        f"{elapsed / global_step:.1f}s/step"
                    )
                    self.log(
                        {
                            "loss": step_loss,
                            "reward": mean_reward,
                            "accuracy": accuracy,
                            "mean_abs_advantage": mean_adv,
                            "learning_rate": self._optim_params["learning_rate"],
                        }
                    )

                if args.save_steps and global_step % args.save_steps == 0:
                    self._save_remote_checkpoint(global_step)

                self.control = self.callback_handler.on_step_end(
                    self.args, self.state, self.control
                )
                if self.control.should_training_stop:
                    break

            if self.control.should_training_stop:
                break

        # Always leave a final remote checkpoint if any step ran.
        if global_step > 0:
            self._save_remote_checkpoint(global_step, name="final")

        elapsed = time.time() - start_time
        avg_loss = total_loss / max(global_step, 1)
        avg_reward = total_reward / max(global_step, 1)

        LOG.info(
            f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
            f"avg_reward={avg_reward:.4f}"
        )

        self.control = self.callback_handler.on_train_end(
            self.args, self.state, self.control
        )

        return TrainOutput(
            global_step=global_step,
            training_loss=avg_loss,
            metrics={
                "train_loss": avg_loss,
                "train_reward": avg_reward,
                "train_runtime": elapsed,
            },
        )
|
||||
|
||||
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
    """Save RL training state on the remote service.

    Uses ``name`` verbatim when provided, otherwise derives a name from
    the configured prefix and the step number. Failures are logged and
    swallowed for periodic saves, but re-raised for the final save so
    the run does not end without a checkpoint.
    """
    tc = self._get_training_client()
    cfg = self.hatchery_args
    assert cfg is not None  # validated by _get_training_client

    ckpt_name = name if name else f"{cfg.save_name_prefix}-{step:06d}"
    try:
        tc.save_state(ckpt_name).result(timeout=cfg.future_timeout)
        LOG.info(f"Remote checkpoint saved: {ckpt_name}")
    except Exception:
        LOG.exception(f"Failed to save checkpoint {ckpt_name}")
        # Periodic saves are best-effort; the final one must not fail silently.
        if name == "final":
            raise
def save_model(self, output_dir=None, _internal_call=False):
    """Route HF Trainer save requests to a remote checkpoint save.

    ``output_dir`` (normally a local path) is reused as the remote
    checkpoint name; a default name is used when it is absent.
    """
    target = output_dir if output_dir else "hf-save"
    self._save_remote_checkpoint(step=self.state.global_step, name=target)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    """Never runs locally: the remote service computes the loss."""
    message = "HatcheryRLTrainer uses remote API; compute_loss not called locally."
    raise NotImplementedError(message)
327
src/axolotl/integrations/hatchery/trainer.py
Normal file
327
src/axolotl/integrations/hatchery/trainer.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Remote trainer that dispatches to Tinker or Hatchery API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from transformers.trainer_utils import TrainOutput
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .args import HatcheryConfig
|
||||
from .data import batch_to_datums_sft, datums_to_tinker
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_loss(result) -> float:
|
||||
"""Extract loss:sum from a forward_backward result.
|
||||
|
||||
Tinker's cross_entropy (and other losses) return the SUM of per-token
|
||||
losses, not the mean. This is by design — it lets users control
|
||||
normalization via the weights tensor. The trainer logs this raw sum;
|
||||
users who want per-token loss should divide by number of active tokens.
|
||||
"""
|
||||
if hasattr(result, "metrics"):
|
||||
metrics = result.metrics or {}
|
||||
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
||||
if isinstance(result, dict):
|
||||
metrics = result.get("metrics", {})
|
||||
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
||||
return 0.0
|
||||
|
||||
|
||||
def _create_training_client(args: HatcheryConfig, base_model: str):
    """Create a training client for either Tinker or Hatchery backend."""
    if args.backend == "tinker":
        return _tinker_lora_client(args, base_model)
    return _hatchery_lora_client(args, base_model)


def _tinker_lora_client(args: HatcheryConfig, base_model: str):
    """Open a Tinker LoRA training session (requires an API key)."""
    import tinker

    key = args.api_key or os.environ.get("TINKER_API_KEY")
    if not key:
        raise ValueError(
            "Tinker API key required. Set `hatchery.api_key` in config "
            "or TINKER_API_KEY env var."
        )
    # The tinker SDK reads the key from the environment.
    os.environ["TINKER_API_KEY"] = key

    svc = tinker.ServiceClient(project_id=args.project_id)
    return svc.create_lora_training_client(
        base_model=base_model,
        rank=args.lora_rank,
        train_mlp=args.train_mlp,
        train_attn=args.train_attn,
        train_unembed=args.train_unembed,
    )


def _hatchery_lora_client(args: HatcheryConfig, base_model: str):
    """Open a Hatchery LoRA training session (falls back to local-dev defaults)."""
    from hatchery.core.client import HatcheryClient

    url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
    tok = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")

    hc = HatcheryClient(base_url=url, token=tok, timeout=args.future_timeout)
    return hc.create_lora_training_client(
        base_model=base_model,
        rank=args.lora_rank,
        train_attn=args.train_attn,
        train_mlp=args.train_mlp,
        train_unembed=args.train_unembed,
    )
class HatcheryTrainer(AxolotlTrainer):
    """Trainer that sends preprocessed batches to a remote training API.

    Replaces local forward/backward with remote API calls to Tinker or
    Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
    chat templates, packing, etc.) but offloads compute to remote GPUs.
    """

    # Set externally (by the plugin) after construction.
    hatchery_args: Optional[HatcheryConfig]
    _base_model_name: Optional[str]
    _training_client: Any

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hatchery_args = None
        self._base_model_name = None
        self._training_client = None  # lazily created on first use

    def _get_training_client(self):
        """Lazily create (and cache) the remote training session."""
        if self._training_client is not None:
            return self._training_client

        args = self.hatchery_args
        if args is None:
            raise RuntimeError(
                "HatcheryTrainer.hatchery_args not set. "
                "Ensure the HatcheryPlugin is registered."
            )

        base_model = self._base_model_name
        if not base_model:
            raise RuntimeError("HatcheryTrainer._base_model_name not set.")

        self._training_client = _create_training_client(args, base_model)

        LOG.info(
            f"Remote training session created: backend={args.backend}, "
            f"model={base_model}, rank={args.lora_rank}"
        )
        return self._training_client

    def _send_batch(self, batch: dict[str, torch.Tensor]):
        """Convert batch to datums and send forward_backward to remote.

        Returns (future, n_active_tokens) where n_active_tokens counts
        the completion tokens in this batch (for loss normalization).
        """
        input_ids = batch["input_ids"]
        labels = batch["labels"]
        attention_mask = batch.get("attention_mask")

        # Labels are shifted by one for next-token prediction; -100 marks
        # positions excluded from the loss.
        n_active = int((labels[:, 1:] != -100).sum().item())
        datums = batch_to_datums_sft(input_ids, labels, attention_mask)

        tc = self._get_training_client()
        args = self.hatchery_args
        assert args is not None  # validated by _get_training_client
        send_datums = datums_to_tinker(datums)

        future = tc.forward_backward(
            send_datums,
            loss_fn=args.loss_fn,
            loss_fn_config=args.loss_fn_config,
        )
        return future, n_active

    def _do_optim_step(self):
        """Send optimizer step to remote using axolotl's training params."""
        import tinker.types as tt

        tc = self._get_training_client()
        return tc.optim_step(tt.AdamParams(**self._optim_params))

    def _apply_accumulated_step(self, pending_fb_futures):
        """Resolve pending forward_backward futures and apply one optim step.

        Args:
            pending_fb_futures: list of (future, n_active_tokens) pairs for
                the batches accumulated into this optimizer step.

        Returns:
            (step_loss, step_active): loss normalized per active token (or
            the raw loss sum when no active tokens were counted) and the
            active-token count for logging.
        """
        args = self.hatchery_args
        assert args is not None  # validated by _get_training_client

        step_loss_sum = 0.0
        step_active = 0
        for fut, n_act in pending_fb_futures:
            result = fut.result(timeout=args.future_timeout)
            step_loss_sum += _extract_loss(result)
            step_active += n_act

        optim_future = self._do_optim_step()
        # In pipeline mode the optim step is left in flight to overlap with
        # the next forward_backward round-trip.
        if not args.pipeline:
            optim_future.result(timeout=args.future_timeout)

        step_loss = (
            step_loss_sum / step_active if step_active > 0 else step_loss_sum
        )
        return step_loss, step_active

    def train(
        self,
        resume_from_checkpoint: Optional[str] = None,
        trial: Any = None,
        ignore_keys_for_eval: Optional[list[str]] = None,
        **kwargs,
    ) -> TrainOutput:
        """Main training loop — sends batches to remote API."""
        args = self.hatchery_args
        if args is None:
            raise RuntimeError("hatchery_args not configured")

        train_dataloader = self.get_train_dataloader()
        num_batches = len(train_dataloader)

        grad_accum = self.args.gradient_accumulation_steps
        num_train_epochs = int(self.args.num_train_epochs)
        steps_per_epoch = max(num_batches // grad_accum, 1)
        max_steps = (
            self.args.max_steps
            if self.args.max_steps > 0
            else steps_per_epoch * num_train_epochs
        )

        LOG.info(
            f"Remote training: {num_batches} batches/epoch, "
            f"{grad_accum} grad_accum, {max_steps} max steps, "
            f"{num_train_epochs} epochs"
        )

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        # Remote training runs in a single local process; always rank 0.
        self.state.is_local_process_zero = True
        self.state.is_world_process_zero = True

        self.control = self.callback_handler.on_train_begin(
            self.args,
            self.state,
            self.control,  # type: ignore[has-type]
        )

        global_step = 0
        total_loss = 0.0
        start_time = time.time()

        for _epoch in range(num_train_epochs):
            if global_step >= max_steps:
                break

            self.control = self.callback_handler.on_epoch_begin(
                self.args, self.state, self.control
            )

            pending_fb_futures = []
            accum_count = 0

            for batch_idx, batch in enumerate(train_dataloader):
                if global_step >= max_steps:
                    break

                self.control = self.callback_handler.on_step_begin(
                    self.args, self.state, self.control
                )

                fb_future, n_active = self._send_batch(batch)
                pending_fb_futures.append((fb_future, n_active))
                accum_count += 1

                if accum_count >= grad_accum:
                    step_loss, step_active = self._apply_accumulated_step(
                        pending_fb_futures
                    )

                    global_step += 1
                    total_loss += step_loss
                    self.state.global_step = global_step
                    self.state.epoch = _epoch + (batch_idx + 1) / num_batches

                    log_interval = self.args.logging_steps or 1
                    if global_step % log_interval == 0:
                        elapsed = time.time() - start_time
                        avg_loss = total_loss / global_step
                        LOG.info(
                            f"[step {global_step}/{max_steps}] "
                            f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
                            f"active={step_active} "
                            f"{elapsed / global_step:.2f}s/step"
                        )
                        self.log(
                            {
                                "loss": step_loss,
                                "learning_rate": self._optim_params["learning_rate"],
                                "epoch": self.state.epoch,
                            }
                        )

                    if args.save_steps and global_step % args.save_steps == 0:
                        self._save_remote_checkpoint(global_step)

                    self.control = self.callback_handler.on_step_end(
                        self.args, self.state, self.control
                    )

                    pending_fb_futures = []
                    accum_count = 0

                    if self.control.should_training_stop:
                        break

            # FIX: flush a trailing partial accumulation window. These
            # forward_backward calls were already dispatched, so the remote
            # service holds their gradients; previously the futures were
            # discarded unresolved (errors never surfaced) and the gradients
            # either leaked into the next epoch's first optim step or were
            # dropped on the last epoch. Apply them as one smaller step.
            if (
                pending_fb_futures
                and global_step < max_steps
                and not self.control.should_training_stop
            ):
                step_loss, _ = self._apply_accumulated_step(pending_fb_futures)
                global_step += 1
                total_loss += step_loss
                self.state.global_step = global_step
                self.state.epoch = float(_epoch + 1)
                pending_fb_futures = []
                accum_count = 0

            if self.control.should_training_stop:
                break

            self.control = self.callback_handler.on_epoch_end(
                self.args, self.state, self.control
            )
            if self.control.should_training_stop:
                break

        if global_step > 0:
            self._save_remote_checkpoint(global_step, name="final")

        elapsed = time.time() - start_time
        avg_loss = total_loss / max(global_step, 1)

        LOG.info(
            f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
            f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
        )

        self.control = self.callback_handler.on_train_end(
            self.args, self.state, self.control
        )

        return TrainOutput(
            global_step=global_step,
            training_loss=avg_loss,
            metrics={"train_loss": avg_loss, "train_runtime": elapsed},
        )

    def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
        """Save a checkpoint on the remote service.

        Periodic save failures are logged and swallowed; a failure of the
        final checkpoint is re-raised so it cannot be lost silently.
        """
        tc = self._get_training_client()
        args = self.hatchery_args
        assert args is not None  # validated by _get_training_client
        ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
        try:
            future = tc.save_state(ckpt_name)
            future.result(timeout=args.future_timeout)
            LOG.info(f"Remote checkpoint saved: {ckpt_name}")
        except Exception:
            LOG.exception(f"Failed to save checkpoint {ckpt_name}")
            if name == "final":
                raise

    def save_model(self, output_dir=None, _internal_call=False):
        """Delegate to remote checkpoint save so HF callbacks create checkpoints."""
        # NOTE(review): output_dir (a local path) is reused as the remote
        # checkpoint name — confirm the remote API accepts path-like names.
        self._save_remote_checkpoint(
            step=self.state.global_step,
            name=output_dir or "hf-save",
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Loss is computed remotely; local invocation is a programming error."""
        raise NotImplementedError(
            "HatcheryTrainer uses remote API; compute_loss should not be called."
        )
Reference in New Issue
Block a user