feat(remote): real LoRA training script with metric streaming

2026-04-26 01:50:00 +00:00
parent ace187d2b2
commit aa85130ee8

packaging/remote/train.py (new file, 235 lines added)

@@ -0,0 +1,235 @@
#!/usr/bin/env python3
"""LoRA fine-tuning for causal LMs.
Designed to be invoked from llm-trainer dashboard. Streams structured progress
events via stdout so the dashboard can render real metrics instead of tailing logs.
Progress events are single-line JSON prefixed with [METRIC]:
[METRIC] {"type":"metric","step":12,"max_steps":120,"loss":1.83,...}
Status events use [STATUS] for non-metric milestones:
[STATUS] {"phase":"loading_model","detail":"meta-llama/Llama-3-8B"}
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
# Defer heavy imports so --help is fast and missing-deps surface as a clean error.
def _emit(kind, payload):
    print(f"[{kind}] {json.dumps(payload)}", flush=True)


def status(phase, **extra):
    _emit("STATUS", {"phase": phase, **extra})


def parse_args():
    p = argparse.ArgumentParser(description="LoRA fine-tune a causal LM")
    p.add_argument("--model", required=True, help="HF model id or local path")
    p.add_argument("--dataset", required=True, help="Path to JSONL dataset")
    p.add_argument("--output", required=True, help="Output directory")
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=2)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--lora-r", type=int, default=16)
    p.add_argument("--lora-alpha", type=int, default=32)
    p.add_argument("--lora-dropout", type=float, default=0.05)
    p.add_argument("--max-seq-length", type=int, default=2048)
    p.add_argument("--gradient-accumulation", type=int, default=4)
    p.add_argument("--quantize", choices=["4bit", "8bit", "none"], default="4bit")
    p.add_argument("--warmup-ratio", type=float, default=0.03)
    p.add_argument("--scheduler", default="cosine",
                   choices=["linear", "cosine", "constant", "cosine_with_restarts"])
    p.add_argument("--save-steps", type=int, default=50)
    p.add_argument("--eval-split", type=float, default=0.0,
                   help="Fraction held out for eval (0 disables)")
    return p.parse_args()


def format_example(example):
    """Map common dataset schemas into a single text field."""
if "text" in example and example["text"]:
return {"text": example["text"]}
if "question" in example and "answer" in example:
return {"text": f"### Instruction:\n{example['question']}\n\n### Response:\n{example['answer']}"}
if "instruction" in example and "output" in example:
prefix = f"### Instruction:\n{example['instruction']}"
if example.get("input"):
prefix += f"\n\n### Input:\n{example['input']}"
return {"text": f"{prefix}\n\n### Response:\n{example['output']}"}
if "messages" in example and isinstance(example["messages"], list):
parts = []
for m in example["messages"]:
role = m.get("role", "user")
content = m.get("content", "")
parts.append(f"{role}: {content}")
return {"text": "\n".join(parts)}
raise ValueError(f"Unrecognized dataset schema. Keys: {list(example.keys())}")


def main():
    args = parse_args()
    status("starting", args=vars(args))

    status("importing")
    import torch
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM, AutoTokenizer,
        BitsAndBytesConfig, TrainingArguments, TrainerCallback,
    )
    from trl import SFTTrainer

    if not torch.cuda.is_available():
        status("error", detail="No CUDA GPU detected — training will fail.")
        sys.exit(2)
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
    status("gpu_detected", name=gpu_name, memory_gb=gpu_mem_gb)

    # ── Dataset ───────────────────────────────────────────────
    status("loading_dataset", path=args.dataset)
    suffix = Path(args.dataset).suffix.lower()
    if suffix not in (".jsonl", ".json"):
        raise ValueError(f"Dataset must be .jsonl or .json, got {suffix}")
    raw = load_dataset("json", data_files=args.dataset, split="train")
    raw = raw.map(format_example, remove_columns=[c for c in raw.column_names if c != "text"])

    eval_ds = None
    if args.eval_split > 0 and len(raw) >= 10:
        split = raw.train_test_split(test_size=args.eval_split, seed=42)
        train_ds, eval_ds = split["train"], split["test"]
    else:
        train_ds = raw
    status("dataset_loaded", train_examples=len(train_ds),
           eval_examples=len(eval_ds) if eval_ds else 0)

    # ── Tokenizer ─────────────────────────────────────────────
    status("loading_tokenizer", model=args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
# ── Model ─────────────────────────────────────────────────
bnb_config = None
if args.quantize == "4bit":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
elif args.quantize == "8bit":
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
status("loading_model", model=args.model, quantize=args.quantize)
model = AutoModelForCausalLM.from_pretrained(
args.model,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
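    # prepare_model_for_kbit_training freezes the quantized base weights, upcasts
    # norm layers to fp32, and enables input gradients so gradient checkpointing
    # works with the k-bit model (summary of peft behavior, not exhaustive).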
    if bnb_config:
        model = prepare_model_for_kbit_training(model)

    # ── LoRA ──────────────────────────────────────────────────
    status("applying_lora", r=args.lora_r, alpha=args.lora_alpha)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    status("lora_applied", trainable_params=trainable, total_params=total,
           trainable_pct=round(100 * trainable / total, 3))
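
    # Note: the target_modules above match Llama/Mistral-style decoder blocks; other
    # architectures may name their projections differently. With r=16 on a 7-8B
    # model, the trainable fraction typically lands well under 1% of total params.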

    # ── Trainer ───────────────────────────────────────────────
    class JSONCallback(TrainerCallback):
        def __init__(self):
            self.start = time.time()

        def on_log(self, args_, state, control, logs=None, **kw):
            if not logs:
                return
            elapsed = time.time() - self.start
            ev = {
                "type": "metric",  # match the [METRIC] schema in the module docstring
                "step": state.global_step,
                "max_steps": state.max_steps,
                "epoch": logs.get("epoch"),
                "loss": logs.get("loss"),
                "eval_loss": logs.get("eval_loss"),
                "learning_rate": logs.get("learning_rate"),
                "grad_norm": logs.get("grad_norm"),
                "elapsed_s": round(elapsed, 1),
                "progress_pct": round(state.global_step / max(state.max_steps, 1) * 100, 1),
            }
            if state.global_step > 0 and elapsed > 0:
                steps_per_s = state.global_step / elapsed
                remaining = max(state.max_steps - state.global_step, 0)
                ev["eta_s"] = round(remaining / steps_per_s, 1) if steps_per_s else None
                ev["steps_per_s"] = round(steps_per_s, 3)
            _emit("METRIC", ev)
    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        bf16=True,
        logging_steps=1,
        save_steps=args.save_steps,
        save_total_limit=2,
        eval_strategy="steps" if eval_ds else "no",
        eval_steps=args.save_steps if eval_ds else None,
        optim="paged_adamw_8bit" if bnb_config else "adamw_torch",
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        report_to="none",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        callbacks=[JSONCallback()],
    )
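    # The tokenizer / max_seq_length / dataset_text_field keyword arguments match the
    # older trl SFTTrainer API; newer trl releases move these into an SFTConfig, so
    # this assumes a trl version that still accepts them directly.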
status("training")
trainer.train()
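
    # save_pretrained on a PEFT-wrapped model writes only the LoRA adapter weights
    # and adapter config, not a merged full model.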
status("saving", path=args.output)
model.save_pretrained(args.output)
tokenizer.save_pretrained(args.output)
status("done", adapter_path=args.output)


if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        raise
    except Exception as e:
        status("error", detail=str(e))
        raise