diff --git a/packaging/remote/train.py b/packaging/remote/train.py
new file mode 100644
index 0000000..a6f3d58
--- /dev/null
+++ b/packaging/remote/train.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+"""LoRA fine-tuning for causal LMs.
+
+Designed to be invoked from the llm-trainer dashboard. Streams structured progress
+events via stdout so the dashboard can render real metrics instead of tailing logs.
+
+Progress events are single-line JSON prefixed with [METRIC]:
+    [METRIC] {"type":"metric","step":12,"max_steps":120,"loss":1.83,...}
+
+Status events use [STATUS] for non-metric milestones:
+    [STATUS] {"phase":"loading_model","detail":"meta-llama/Llama-3-8B"}
+"""
+
+import argparse
+import json
+import os
+import sys
+import time
+from pathlib import Path
+
+# Defer heavy imports so --help is fast and missing deps surface as a clean error.
+
+def _emit(kind, payload):
+    print(f"[{kind}] {json.dumps(payload)}", flush=True)
+
+
+def status(phase, **extra):
+    _emit("STATUS", {"phase": phase, **extra})
+
+
+def parse_args():
+    p = argparse.ArgumentParser(description="LoRA fine-tune a causal LM")
+    p.add_argument("--model", required=True, help="HF model id or local path")
+    p.add_argument("--dataset", required=True, help="Path to JSONL dataset")
+    p.add_argument("--output", required=True, help="Output directory")
+    p.add_argument("--epochs", type=int, default=3)
+    p.add_argument("--batch-size", type=int, default=2)
+    p.add_argument("--lr", type=float, default=2e-4)
+    p.add_argument("--lora-r", type=int, default=16)
+    p.add_argument("--lora-alpha", type=int, default=32)
+    p.add_argument("--lora-dropout", type=float, default=0.05)
+    p.add_argument("--max-seq-length", type=int, default=2048)
+    p.add_argument("--gradient-accumulation", type=int, default=4)
+    p.add_argument("--quantize", choices=["4bit", "8bit", "none"], default="4bit")
+    p.add_argument("--warmup-ratio", type=float, default=0.03)
+    p.add_argument("--scheduler", default="cosine",
+                   choices=["linear", "cosine", "constant", "cosine_with_restarts"])
+    p.add_argument("--save-steps", type=int, default=50)
+    p.add_argument("--eval-split", type=float, default=0.0,
+                   help="Fraction held out for eval (0 disables)")
+    return p.parse_args()
+
+
+def format_example(example):
+    """Map common dataset schemas into a single text field."""
+    if "text" in example and example["text"]:
+        return {"text": example["text"]}
+    if "question" in example and "answer" in example:
+        return {"text": f"### Instruction:\n{example['question']}\n\n### Response:\n{example['answer']}"}
+    if "instruction" in example and "output" in example:
+        prefix = f"### Instruction:\n{example['instruction']}"
+        if example.get("input"):
+            prefix += f"\n\n### Input:\n{example['input']}"
+        return {"text": f"{prefix}\n\n### Response:\n{example['output']}"}
+    if "messages" in example and isinstance(example["messages"], list):
+        parts = []
+        for m in example["messages"]:
+            role = m.get("role", "user")
+            content = m.get("content", "")
+            parts.append(f"{role}: {content}")
+        return {"text": "\n".join(parts)}
+    raise ValueError(f"Unrecognized dataset schema. Keys: {list(example.keys())}")
+
+
+def main():
+    args = parse_args()
+    status("starting", args=vars(args))
+
+    status("importing")
+    import torch
+    from datasets import load_dataset
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+    from transformers import (
+        AutoModelForCausalLM, AutoTokenizer,
+        BitsAndBytesConfig, TrainingArguments, TrainerCallback,
+    )
+    from trl import SFTTrainer
+
+    if not torch.cuda.is_available():
+        status("error", detail="No CUDA GPU detected — training will fail.")
+        sys.exit(2)
+
+    gpu_name = torch.cuda.get_device_name(0)
+    gpu_mem_gb = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
+    status("gpu_detected", name=gpu_name, memory_gb=gpu_mem_gb)
+
+    # ── Dataset ─────────────────────────────────────────────
+    status("loading_dataset", path=args.dataset)
+    suffix = Path(args.dataset).suffix.lower()
+    if suffix not in (".jsonl", ".json"):
+        raise ValueError(f"Dataset must be .jsonl or .json, got {suffix}")
+    raw = load_dataset("json", data_files=args.dataset, split="train")
+    raw = raw.map(format_example, remove_columns=[c for c in raw.column_names if c != "text"])
+    eval_ds = None
+    if args.eval_split > 0 and len(raw) >= 10:
+        split = raw.train_test_split(test_size=args.eval_split, seed=42)
+        train_ds, eval_ds = split["train"], split["test"]
+    else:
+        train_ds = raw
+    status("dataset_loaded", train_examples=len(train_ds),
+           eval_examples=len(eval_ds) if eval_ds else 0)
+
+    # ── Tokenizer ───────────────────────────────────────────
+    status("loading_tokenizer", model=args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # ── Model ───────────────────────────────────────────────
+    bnb_config = None
+    if args.quantize == "4bit":
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=True,
+        )
+    elif args.quantize == "8bit":
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+
+    status("loading_model", model=args.model, quantize=args.quantize)
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        quantization_config=bnb_config,
+        device_map="auto",
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
+    if bnb_config:
+        model = prepare_model_for_kbit_training(model)
+
+    # ── LoRA ────────────────────────────────────────────────
+    status("applying_lora", r=args.lora_r, alpha=args.lora_alpha)
+    lora_config = LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+    )
+    model = get_peft_model(model, lora_config)
+    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    total = sum(p.numel() for p in model.parameters())
+    status("lora_applied", trainable_params=trainable, total_params=total,
+           trainable_pct=round(100 * trainable / total, 3))
+
+    # ── Trainer ─────────────────────────────────────────────
+    class JSONCallback(TrainerCallback):
+        def __init__(self):
+            self.start = time.time()
+
+        def on_log(self, args_, state, control, logs=None, **kw):
+            if not logs:
+                return
+            elapsed = time.time() - self.start
+            ev = {
+                "step": state.global_step,
+                "max_steps": state.max_steps,
+                "epoch": logs.get("epoch"),
+                "loss": logs.get("loss"),
+                "eval_loss": logs.get("eval_loss"),
+                "learning_rate": logs.get("learning_rate"),
+                "grad_norm": logs.get("grad_norm"),
+                "elapsed_s": round(elapsed, 1),
+                "progress_pct": round(state.global_step / max(state.max_steps, 1) * 100, 1),
+            }
+            if state.global_step > 0 and elapsed > 0:
+                steps_per_s = state.global_step / elapsed
+                remaining = max(state.max_steps - state.global_step, 0)
+                ev["eta_s"] = round(remaining / steps_per_s, 1) if steps_per_s else None
+                ev["steps_per_s"] = round(steps_per_s, 3)
+            _emit("METRIC", ev)
+
+    training_args = TrainingArguments(
+        output_dir=args.output,
+        num_train_epochs=args.epochs,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation,
+        learning_rate=args.lr,
+        bf16=True,
+        logging_steps=1,
+        save_steps=args.save_steps,
+        save_total_limit=2,
+        eval_strategy="steps" if eval_ds else "no",
+        eval_steps=args.save_steps if eval_ds else None,
+        optim="paged_adamw_8bit" if bnb_config else "adamw_torch",
+        warmup_ratio=args.warmup_ratio,
+        lr_scheduler_type=args.scheduler,
+        report_to="none",
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=train_ds,
+        eval_dataset=eval_ds,
+        tokenizer=tokenizer,
+        args=training_args,
+        max_seq_length=args.max_seq_length,
+        dataset_text_field="text",
+        callbacks=[JSONCallback()],
+    )
+
+    status("training")
+    trainer.train()
+
+    status("saving", path=args.output)
+    model.save_pretrained(args.output)
+    tokenizer.save_pretrained(args.output)
+
+    status("done", adapter_path=args.output)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except SystemExit:
+        raise
+    except Exception as e:
+        status("error", detail=str(e))
+        raise
\ No newline at end of file