#!/usr/bin/env python3
"""LoRA fine-tuning for causal LMs.

Designed to be invoked from the llm-trainer dashboard. Streams structured
progress events via stdout so the dashboard can render real metrics instead
of tailing logs. Progress events are single-line JSON prefixed with [METRIC]:

    [METRIC] {"type":"metric","step":12,"max_steps":120,"loss":1.83,...}

Status events use [STATUS] for non-metric milestones:

    [STATUS] {"phase":"loading_model","detail":"meta-llama/Llama-3-8B"}
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path

# Defer heavy imports so --help is fast and missing-deps surface as a clean error.


def _emit(kind, payload):
    print(f"[{kind}] {json.dumps(payload)}", flush=True)


def status(phase, **extra):
    _emit("STATUS", {"phase": phase, **extra})


def parse_args():
    p = argparse.ArgumentParser(description="LoRA fine-tune a causal LM")
    p.add_argument("--model", required=True, help="HF model id or local path")
    p.add_argument("--dataset", required=True, help="Path to JSONL dataset")
    p.add_argument("--output", required=True, help="Output directory")
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=2)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--lora-r", type=int, default=16)
    p.add_argument("--lora-alpha", type=int, default=32)
    p.add_argument("--lora-dropout", type=float, default=0.05)
    p.add_argument("--max-seq-length", type=int, default=2048)
    p.add_argument("--gradient-accumulation", type=int, default=4)
    p.add_argument("--quantize", choices=["4bit", "8bit", "none"], default="4bit")
    p.add_argument("--warmup-ratio", type=float, default=0.03)
    p.add_argument("--scheduler", default="cosine",
                   choices=["linear", "cosine", "constant", "cosine_with_restarts"])
    p.add_argument("--save-steps", type=int, default=50)
    p.add_argument("--eval-split", type=float, default=0.0,
                   help="Fraction held out for eval (0 disables)")
    return p.parse_args()


def format_example(example):
    """Map common dataset schemas into a single text field."""
    if "text" in example and example["text"]:
        return {"text": example["text"]}
    if "question" in example and "answer" in example:
        return {"text": f"### Instruction:\n{example['question']}\n\n### Response:\n{example['answer']}"}
    if "instruction" in example and "output" in example:
        prefix = f"### Instruction:\n{example['instruction']}"
        if example.get("input"):
            prefix += f"\n\n### Input:\n{example['input']}"
        return {"text": f"{prefix}\n\n### Response:\n{example['output']}"}
    if "messages" in example and isinstance(example["messages"], list):
        parts = []
        for m in example["messages"]:
            role = m.get("role", "user")
            content = m.get("content", "")
            parts.append(f"{role}: {content}")
        return {"text": "\n".join(parts)}
    raise ValueError(f"Unrecognized dataset schema. Keys: {list(example.keys())}")
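
# Illustrative mapping for the instruction/input/output branch above
# (the record values here are made-up placeholders, not from any dataset):
#
#   {"instruction": "Summarize the text",
#    "input": "LoRA adds low-rank adapter matrices to attention layers.",
#    "output": "LoRA trains small adapters instead of full weights."}
#
# is flattened into a single "text" field:
#
#   ### Instruction:
#   Summarize the text
#
#   ### Input:
#   LoRA adds low-rank adapter matrices to attention layers.
#
#   ### Response:
#   LoRA trains small adapters instead of full weights.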


def main():
    args = parse_args()
    status("starting", args=vars(args))

    status("importing")
    import torch
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainingArguments,
        TrainerCallback,
    )
    from trl import SFTTrainer

    if not torch.cuda.is_available():
        status("error", detail="No CUDA GPU detected — training will fail.")
        sys.exit(2)
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
    status("gpu_detected", name=gpu_name, memory_gb=gpu_mem_gb)

    # ── Dataset ───────────────────────────────────────────────
    status("loading_dataset", path=args.dataset)
    suffix = Path(args.dataset).suffix.lower()
    if suffix not in (".jsonl", ".json"):
        raise ValueError(f"Dataset must be .jsonl or .json, got {suffix}")
    raw = load_dataset("json", data_files=args.dataset, split="train")
    raw = raw.map(format_example,
                  remove_columns=[c for c in raw.column_names if c != "text"])
    eval_ds = None
    if args.eval_split > 0 and len(raw) >= 10:
        split = raw.train_test_split(test_size=args.eval_split, seed=42)
        train_ds, eval_ds = split["train"], split["test"]
    else:
        train_ds = raw
    status("dataset_loaded", train_examples=len(train_ds),
           eval_examples=len(eval_ds) if eval_ds else 0)

    # ── Tokenizer ─────────────────────────────────────────────
    status("loading_tokenizer", model=args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ── Model ─────────────────────────────────────────────────
    bnb_config = None
    if args.quantize == "4bit":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif args.quantize == "8bit":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

    status("loading_model", model=args.model, quantize=args.quantize)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    if bnb_config:
        model = prepare_model_for_kbit_training(model)

    # ── LoRA ──────────────────────────────────────────────────
    status("applying_lora", r=args.lora_r, alpha=args.lora_alpha)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    status("lora_applied", trainable_params=trainable, total_params=total,
           trainable_pct=round(100 * trainable / total, 3))

    # ── Trainer ───────────────────────────────────────────────
    class JSONCallback(TrainerCallback):
        def __init__(self):
            self.start = time.time()

        def on_log(self, args_, state, control, logs=None, **kw):
            if not logs:
                return
            elapsed = time.time() - self.start
            ev = {
                "type": "metric",
                "step": state.global_step,
                "max_steps": state.max_steps,
                "epoch": logs.get("epoch"),
                "loss": logs.get("loss"),
                "eval_loss": logs.get("eval_loss"),
                "learning_rate": logs.get("learning_rate"),
                "grad_norm": logs.get("grad_norm"),
                "elapsed_s": round(elapsed, 1),
                "progress_pct": round(state.global_step / max(state.max_steps, 1) * 100, 1),
            }
            if state.global_step > 0 and elapsed > 0:
                steps_per_s = state.global_step / elapsed
                remaining = max(state.max_steps - state.global_step, 0)
                ev["eta_s"] = round(remaining / steps_per_s, 1) if steps_per_s else None
                ev["steps_per_s"] = round(steps_per_s, 3)
            _emit("METRIC", ev)

    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        bf16=True,
        logging_steps=1,
        save_steps=args.save_steps,
        save_total_limit=2,
        eval_strategy="steps" if eval_ds else "no",
        eval_steps=args.save_steps if eval_ds else None,
        optim="paged_adamw_8bit" if bnb_config else "adamw_torch",
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        report_to="none",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        callbacks=[JSONCallback()],
    )

    status("training")
    trainer.train()

    status("saving", path=args.output)
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    status("done", adapter_path=args.output)


if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        raise
    except Exception as e:
        status("error", detail=str(e))
        raise
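
# Example invocation (illustrative; the script filename, model id, and paths
# below are placeholders for whatever the dashboard passes in):
#
#   python lora_finetune.py \
#       --model meta-llama/Llama-3-8B \
#       --dataset data/train.jsonl \
#       --output adapters/llama3-lora \
#       --epochs 3 --batch-size 2 --quantize 4bit --eval-split 0.05
#
# Each logged training step then emits one [METRIC] JSON line on stdout in the
# format documented in the module docstring.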