#!/usr/bin/env python3
"""LoRA fine-tuning for causal LMs.

Designed to be invoked from the llm-trainer dashboard. Streams structured
progress events via stdout so the dashboard can render real metrics instead of
tailing logs.

Progress events are single-line JSON prefixed with [METRIC]:
    [METRIC] {"step": 12, "max_steps": 120, "epoch": 0.3, "loss": 1.83, ...}

Status events use [STATUS] for non-metric milestones:
    [STATUS] {"phase": "loading_model", "model": "meta-llama/Llama-3-8B", "quantize": "4bit"}
"""
import argparse
import json
import math
import os
import sys
import time
from pathlib import Path

# Defer heavy imports so --help is fast and missing-deps surface as a clean error.
def _emit(kind, payload):
    """Print one structured event line, ``[KIND] <json>``, to stdout.

    Args:
        kind: Event tag placed between the square brackets (for example
            ``"STATUS"`` or ``"METRIC"``).
        payload: JSON-serializable value; dicts, lists and tuples are walked
            recursively.

    Non-finite floats (NaN, +/-Infinity) are written as ``null``.  Left alone,
    ``json.dumps`` would emit the bare tokens ``NaN`` / ``Infinity``, which are
    not valid JSON and break strict parsers on the consuming side.  The line
    is flushed immediately so a consumer reading the pipe sees it right away.
    """
    print(f"[{kind}] {json.dumps(_json_safe(payload))}", flush=True)


def _json_safe(value):
    """Return a copy of *value* with every non-finite float replaced by None."""
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        # json.dumps serializes tuples as arrays too, so a list is equivalent.
        return [_json_safe(item) for item in value]
    return value
def status(phase, **extra):
    """Emit a ``[STATUS]`` event for a non-metric milestone.

    Args:
        phase: Name of the milestone; stored under the ``"phase"`` key.
        **extra: Additional fields merged into the payload after ``"phase"``.
    """
    payload = {"phase": phase}
    payload.update(extra)
    _emit("STATUS", payload)
def parse_args(argv=None):
    """Parse the command-line options for one fine-tuning run.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so existing no-argument
            callers behave as before.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: Raised by argparse (exit status 2) for missing or invalid
            options, including an ``--eval-split`` outside ``[0, 1)``.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune a causal LM")
    parser.add_argument("--model", required=True, help="HF model id or local path")
    parser.add_argument("--dataset", required=True, help="Path to JSONL dataset")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--gradient-accumulation", type=int, default=4)
    parser.add_argument("--quantize", choices=["4bit", "8bit", "none"], default="4bit")
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument(
        "--scheduler",
        default="cosine",
        choices=["linear", "cosine", "constant", "cosine_with_restarts"],
    )
    parser.add_argument("--save-steps", type=int, default=50)
    parser.add_argument(
        "--eval-split",
        type=float,
        default=0.0,
        help="Fraction held out for eval (0 disables)",
    )

    args = parser.parse_args(argv)

    # Reject an impossible hold-out fraction up front instead of failing (or
    # being silently ignored) only after the dataset has been loaded.  The
    # chained comparison is also False for NaN.
    if not 0.0 <= args.eval_split < 1.0:
        parser.error("--eval-split must be a fraction in the range [0, 1)")
    return args
def format_example(example):
    """Map common dataset schemas into a single text field.

    Schemas are tried in this order; the first match wins:

    1. ``text`` (non-empty) -- used as-is.
    2. ``question`` + ``answer``.
    3. ``instruction`` + ``output``, with an optional non-empty ``input``.
    4. ``messages`` -- a list of mappings rendered as ``role: content`` lines.

    A key whose value is ``None`` is treated as absent.  This matters for
    mixed-schema files where absent columns arrive as ``None``: a presence-only
    check would pick the wrong schema and bake the literal string "None" into
    the training text.

    Args:
        example: Mapping for one dataset row.

    Returns:
        ``{"text": <formatted string>}``.

    Raises:
        ValueError: If the row matches none of the schemas above.
    """

    def present(*keys):
        # Empty strings still count as present; only None is "missing".
        return all(example.get(key) is not None for key in keys)

    if example.get("text"):
        return {"text": example["text"]}

    if present("question", "answer"):
        return {
            "text": (
                f"### Instruction:\n{example['question']}"
                f"\n\n### Response:\n{example['answer']}"
            )
        }

    if present("instruction", "output"):
        prefix = f"### Instruction:\n{example['instruction']}"
        if example.get("input"):
            prefix += f"\n\n### Input:\n{example['input']}"
        return {"text": f"{prefix}\n\n### Response:\n{example['output']}"}

    messages = example.get("messages")
    if isinstance(messages, list):
        parts = []
        for message in messages:
            role = message.get("role")
            if role is None:
                role = "user"
            content = message.get("content")
            if content is None:
                content = ""
            parts.append(f"{role}: {content}")
        return {"text": "\n".join(parts)}

    raise ValueError(f"Unrecognized dataset schema. Keys: {list(example.keys())}")
def main():
    """Run one LoRA fine-tuning job end to end, reporting progress on stdout.

    Steps, in order: parse the CLI options, import the ML stack, require a
    CUDA device, load and format the JSON/JSONL dataset, load the tokenizer
    and the (optionally quantized) base model, attach LoRA adapters, train
    with ``SFTTrainer``, and save the adapter and tokenizer to ``--output``.
    Each step is announced with a ``[STATUS]`` event; training logs are
    forwarded as ``[METRIC]`` events.

    Raises:
        SystemExit: With status 2 when no CUDA device is available.
        ValueError: When the dataset path does not end in ``.json``/``.jsonl``.
    """
    args = parse_args()
    status("starting", args=vars(args))

    # Heavy third-party imports happen here rather than at module import time.
    status("importing")
    import torch
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainerCallback,
        TrainingArguments,
    )
    from trl import SFTTrainer

    if not torch.cuda.is_available():
        status("error", detail="No CUDA GPU detected — training will fail.")
        sys.exit(2)

    device_name = torch.cuda.get_device_name(0)
    device_mem_gb = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
    status("gpu_detected", name=device_name, memory_gb=device_mem_gb)

    # ── Dataset ───────────────────────────────────────────────
    status("loading_dataset", path=args.dataset)
    extension = Path(args.dataset).suffix.lower()
    if extension not in (".jsonl", ".json"):
        raise ValueError(f"Dataset must be .jsonl or .json, got {extension}")
    dataset = load_dataset("json", data_files=args.dataset, split="train")
    # Reduce every row to a single "text" column.
    non_text_columns = [col for col in dataset.column_names if col != "text"]
    dataset = dataset.map(format_example, remove_columns=non_text_columns)

    # Hold out an eval split only when requested and there are >= 10 rows.
    train_ds, eval_ds = dataset, None
    if args.eval_split > 0 and len(dataset) >= 10:
        parts = dataset.train_test_split(test_size=args.eval_split, seed=42)
        train_ds, eval_ds = parts["train"], parts["test"]
    status(
        "dataset_loaded",
        train_examples=len(train_ds),
        eval_examples=len(eval_ds) if eval_ds else 0,
    )

    # ── Tokenizer ─────────────────────────────────────────────
    status("loading_tokenizer", model=args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS when the tokenizer defines no padding token.
        tokenizer.pad_token = tokenizer.eos_token

    # ── Model ─────────────────────────────────────────────────
    # NOTE(review): bfloat16 is requested unconditionally (4-bit compute dtype,
    # model dtype and bf16=True below) — confirm all target GPUs support it.
    if args.quantize == "4bit":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif args.quantize == "8bit":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        bnb_config = None

    status("loading_model", model=args.model, quantize=args.quantize)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    if bnb_config:
        model = prepare_model_for_kbit_training(model)

    # ── LoRA ──────────────────────────────────────────────────
    status("applying_lora", r=args.lora_r, alpha=args.lora_alpha)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
    model = get_peft_model(model, lora_config)

    # Count parameters in one pass: all of them, and the trainable subset.
    trainable = 0
    total = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    status(
        "lora_applied",
        trainable_params=trainable,
        total_params=total,
        trainable_pct=round(100 * trainable / total, 3),
    )

    # ── Trainer ───────────────────────────────────────────────
    class JSONCallback(TrainerCallback):
        """Forward each trainer log record as a ``[METRIC]`` event."""

        def __init__(self):
            # Wall-clock reference for elapsed time, throughput and ETA.
            self.start = time.time()

        def on_log(self, args_, state, control, logs=None, **kw):
            if not logs:
                return
            elapsed = time.time() - self.start
            step = state.global_step
            max_steps = state.max_steps
            # Fields missing from this log record are reported as null.
            event = {
                "step": step,
                "max_steps": max_steps,
                "epoch": logs.get("epoch"),
                "loss": logs.get("loss"),
                "eval_loss": logs.get("eval_loss"),
                "learning_rate": logs.get("learning_rate"),
                "grad_norm": logs.get("grad_norm"),
                "elapsed_s": round(elapsed, 1),
                "progress_pct": round(step / max(max_steps, 1) * 100, 1),
            }
            if step > 0 and elapsed > 0:
                rate = step / elapsed
                steps_left = max(max_steps - step, 0)
                event["eta_s"] = round(steps_left / rate, 1) if rate else None
                event["steps_per_s"] = round(rate, 3)
            _emit("METRIC", event)

    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        bf16=True,
        logging_steps=1,
        save_steps=args.save_steps,
        save_total_limit=2,
        # Evaluation runs on the same cadence as checkpointing.
        eval_strategy="steps" if eval_ds else "no",
        eval_steps=args.save_steps if eval_ds else None,
        optim="paged_adamw_8bit" if bnb_config else "adamw_torch",
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        report_to="none",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    # NOTE(review): whether SFTTrainer accepts tokenizer= / max_seq_length= /
    # dataset_text_field= directly depends on the TRL release — confirm
    # against the pinned TRL version.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        callbacks=[JSONCallback()],
    )

    status("training")
    trainer.train()

    status("saving", path=args.output)
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)

    status("done", adapter_path=args.output)
if __name__ == "__main__":
    # SystemExit is not an Exception subclass, so it propagates through the
    # handler below untouched. Any other failure is reported as a [STATUS]
    # error event and then re-raised so the traceback and non-zero exit
    # status are kept.
    try:
        main()
    except Exception as exc:
        status("error", detail=str(exc))
        raise