feat(remote): real LoRA training script with metric streaming
This commit is contained in:
235
packaging/remote/train.py
Normal file
235
packaging/remote/train.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python3
|
||||
"""LoRA fine-tuning for causal LMs.
|
||||
|
||||
Designed to be invoked from llm-trainer dashboard. Streams structured progress
|
||||
events via stdout so the dashboard can render real metrics instead of tailing logs.
|
||||
|
||||
Progress events are single-line JSON prefixed with [METRIC]:
|
||||
[METRIC] {"type":"metric","step":12,"max_steps":120,"loss":1.83,...}
|
||||
|
||||
Status events use [STATUS] for non-metric milestones:
|
||||
[STATUS] {"phase":"loading_model","detail":"meta-llama/Llama-3-8B"}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Defer heavy imports so --help is fast and missing-deps surface as a clean error.
|
||||
|
||||
def _emit(kind, payload):
|
||||
print(f"[{kind}] {json.dumps(payload)}", flush=True)
|
||||
|
||||
|
||||
def status(phase, **extra):
    """Emit a [STATUS] milestone event carrying *phase* plus detail fields."""
    payload = {"phase": phase}
    payload.update(extra)
    _emit("STATUS", payload)
|
||||
|
||||
|
||||
def parse_args():
    """Build and evaluate the CLI for the LoRA fine-tuning script.

    Returns the parsed ``argparse.Namespace``. ``--model``, ``--dataset``
    and ``--output`` are mandatory; everything else has a sensible default.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune a causal LM")
    # Required inputs/outputs.
    parser.add_argument("--model", required=True, help="HF model id or local path")
    parser.add_argument("--dataset", required=True, help="Path to JSONL dataset")
    parser.add_argument("--output", required=True, help="Output directory")
    # Core training hyperparameters.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    # LoRA adapter configuration.
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Memory / schedule knobs.
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--gradient-accumulation", type=int, default=4)
    parser.add_argument("--quantize", choices=["4bit", "8bit", "none"], default="4bit")
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument(
        "--scheduler",
        default="cosine",
        choices=["linear", "cosine", "constant", "cosine_with_restarts"],
    )
    # Checkpointing / evaluation.
    parser.add_argument("--save-steps", type=int, default=50)
    parser.add_argument(
        "--eval-split",
        type=float,
        default=0.0,
        help="Fraction held out for eval (0 disables)",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def format_example(example):
    """Normalize one dataset record into a ``{"text": ...}`` mapping.

    Schemas are tried in priority order:
      1. ``{"text": ...}``              — passed through when non-empty
      2. ``{"question", "answer"}``     — rendered as an Alpaca-style prompt
      3. ``{"instruction", "output"}``  — Alpaca, with optional ``"input"``
      4. ``{"messages": [...]}``        — chat transcript flattened to
                                          ``role: content`` lines

    Raises:
        ValueError: when none of the known schemas match.
    """
    text = example.get("text")
    if text:
        return {"text": text}

    if "question" in example and "answer" in example:
        body = f"### Instruction:\n{example['question']}"
        body += f"\n\n### Response:\n{example['answer']}"
        return {"text": body}

    if "instruction" in example and "output" in example:
        sections = [f"### Instruction:\n{example['instruction']}"]
        if example.get("input"):
            sections.append(f"### Input:\n{example['input']}")
        sections.append(f"### Response:\n{example['output']}")
        return {"text": "\n\n".join(sections)}

    msgs = example.get("messages")
    if isinstance(msgs, list):
        lines = [f"{m.get('role', 'user')}: {m.get('content', '')}" for m in msgs]
        return {"text": "\n".join(lines)}

    raise ValueError(f"Unrecognized dataset schema. Keys: {list(example.keys())}")
|
||||
|
||||
|
||||
def main():
    """Run the full LoRA fine-tuning pipeline end to end.

    Phases, each announced via a [STATUS] event: parse CLI args, import
    heavy ML dependencies, verify a CUDA GPU, load and normalize the
    dataset, load tokenizer and (optionally quantized) base model, attach
    LoRA adapters, train while streaming [METRIC] events, then save the
    adapter and tokenizer to ``--output``.

    Exits with code 2 when no CUDA device is available; raises ValueError
    for an unsupported dataset file extension.
    """
    args = parse_args()
    status("starting", args=vars(args))

    # Heavy ML imports are deliberately deferred here (see module comment)
    # so `--help` stays fast and a missing dependency surfaces after the
    # "importing" status event.
    status("importing")
    import torch
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM, AutoTokenizer,
        BitsAndBytesConfig, TrainingArguments, TrainerCallback,
    )
    from trl import SFTTrainer

    # Fail fast: the quantized training path below requires a CUDA device.
    if not torch.cuda.is_available():
        status("error", detail="No CUDA GPU detected — training will fail.")
        sys.exit(2)

    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
    status("gpu_detected", name=gpu_name, memory_gb=gpu_mem_gb)

    # ── Dataset ───────────────────────────────────────────────
    status("loading_dataset", path=args.dataset)
    suffix = Path(args.dataset).suffix.lower()
    if suffix not in (".jsonl", ".json"):
        raise ValueError(f"Dataset must be .jsonl or .json, got {suffix}")
    raw = load_dataset("json", data_files=args.dataset, split="train")
    # Normalize every record to a single "text" column and drop the rest.
    raw = raw.map(format_example, remove_columns=[c for c in raw.column_names if c != "text"])
    eval_ds = None
    # Hold out an eval split only when requested AND the dataset has at
    # least 10 examples, so the split is not degenerate.
    if args.eval_split > 0 and len(raw) >= 10:
        split = raw.train_test_split(test_size=args.eval_split, seed=42)
        train_ds, eval_ds = split["train"], split["test"]
    else:
        train_ds = raw
    status("dataset_loaded", train_examples=len(train_ds),
           eval_examples=len(eval_ds) if eval_ds else 0)

    # ── Tokenizer ─────────────────────────────────────────────
    status("loading_tokenizer", model=args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Many causal-LM tokenizers ship without a pad token; reuse EOS so
    # batched padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ── Model ─────────────────────────────────────────────────
    # Build the bitsandbytes quantization config (None == full precision).
    bnb_config = None
    if args.quantize == "4bit":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif args.quantize == "8bit":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

    status("loading_model", model=args.model, quantize=args.quantize)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    if bnb_config:
        # Required prep for training on top of a k-bit quantized base model.
        model = prepare_model_for_kbit_training(model)

    # ── LoRA ──────────────────────────────────────────────────
    status("applying_lora", r=args.lora_r, alpha=args.lora_alpha)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        # Attention + MLP projection layers (Llama-style module names —
        # NOTE(review): confirm these exist for non-Llama architectures).
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    status("lora_applied", trainable_params=trainable, total_params=total,
           trainable_pct=round(100 * trainable / total, 3))

    # ── Trainer ───────────────────────────────────────────────
    class JSONCallback(TrainerCallback):
        """Streams one [METRIC] JSON event per Trainer log step."""

        def __init__(self):
            # Wall-clock training start, used for elapsed/ETA estimates.
            self.start = time.time()

        def on_log(self, args_, state, control, logs=None, **kw):
            if not logs:
                return
            elapsed = time.time() - self.start
            ev = {
                "step": state.global_step,
                "max_steps": state.max_steps,
                "epoch": logs.get("epoch"),
                "loss": logs.get("loss"),
                "eval_loss": logs.get("eval_loss"),
                "learning_rate": logs.get("learning_rate"),
                "grad_norm": logs.get("grad_norm"),
                "elapsed_s": round(elapsed, 1),
                # max(..., 1) guards against division by zero before
                # max_steps is known.
                "progress_pct": round(state.global_step / max(state.max_steps, 1) * 100, 1),
            }
            # Throughput/ETA only once at least one step has completed.
            if state.global_step > 0 and elapsed > 0:
                steps_per_s = state.global_step / elapsed
                remaining = max(state.max_steps - state.global_step, 0)
                ev["eta_s"] = round(remaining / steps_per_s, 1) if steps_per_s else None
                ev["steps_per_s"] = round(steps_per_s, 3)
            _emit("METRIC", ev)

    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        bf16=True,
        logging_steps=1,  # log every step: the dashboard wants dense metrics
        save_steps=args.save_steps,
        save_total_limit=2,
        eval_strategy="steps" if eval_ds else "no",
        eval_steps=args.save_steps if eval_ds else None,
        # Paged 8-bit optimizer pairs with quantized base models to cap VRAM.
        optim="paged_adamw_8bit" if bnb_config else "adamw_torch",
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        report_to="none",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        callbacks=[JSONCallback()],
    )

    status("training")
    trainer.train()

    # Saves only the LoRA adapter weights (model is a PEFT wrapper), plus
    # the tokenizer so the adapter directory is self-contained.
    status("saving", path=args.output)
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    status("done", adapter_path=args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # Preserve deliberate exit codes (e.g. the no-GPU exit inside main).
        raise
    except Exception as e:
        # Surface any crash to the dashboard as a structured [STATUS] event,
        # then re-raise so the full traceback still reaches the log.
        status("error", detail=str(e))
        raise
|
||||
Reference in New Issue
Block a user