more logging, wandb fixes
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
@@ -37,6 +38,9 @@ from axolotl.prompt_tokenizers import (
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
|
||||
|
||||
|
||||
def setup_wandb_env_vars(cfg):
|
||||
if len(cfg.wandb_project) > 0:
|
||||
@@ -46,6 +50,8 @@ def setup_wandb_env_vars(cfg):
|
||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||
|
||||
|
||||
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
||||
@@ -164,8 +170,8 @@ def check_dataset_labels(dataset, tokenizer):
|
||||
)
|
||||
colored_tokens.append(colored_token)
|
||||
|
||||
print(" ".join(colored_tokens))
|
||||
print("\n\n\n")
|
||||
logger.info(" ".join(colored_tokens))
|
||||
logger.info("\n\n\n")
|
||||
|
||||
|
||||
def do_inference(cfg, model, tokenizer):
|
||||
@@ -247,7 +253,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||
group_by_length=cfg.group_by_length,
|
||||
report_to="wandb" if cfg.use_wandb else None,
|
||||
run_name=cfg.wandb_run_name if cfg.use_wandb else None,
|
||||
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
@@ -341,9 +347,9 @@ def train(
|
||||
return
|
||||
|
||||
if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")):
|
||||
print("Loading prepared dataset from disk...")
|
||||
dataset = load_from_disk(cfg.datasets)
|
||||
print("Prepared dataset loaded from disk...")
|
||||
logger.info("Loading prepared dataset from disk...")
|
||||
dataset = load_from_disk(cfg.dataset_prepared_path)
|
||||
logger.info("Prepared dataset loaded from disk...")
|
||||
else:
|
||||
datasets = []
|
||||
for d in cfg.datasets:
|
||||
@@ -376,11 +382,12 @@ def train(
|
||||
[_ for _ in constant_len_dataset]
|
||||
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
||||
|
||||
print("Saving prepared dataset to disk...")
|
||||
if cfg.dataset_prepared_path:
|
||||
dataset.save_to_disk(cfg.dataset_prepared_path)
|
||||
else:
|
||||
dataset.save_to_disk("data/last_run")
|
||||
if cfg.local_rank == 0:
|
||||
logger.info("Saving prepared dataset to disk...")
|
||||
if cfg.dataset_prepared_path:
|
||||
dataset.save_to_disk(cfg.dataset_prepared_path)
|
||||
else:
|
||||
dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH)
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
@@ -396,9 +403,11 @@ def train(
|
||||
model.config.use_cache = False
|
||||
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
logger.info("Compiling torch model")
|
||||
model = torch.compile(model)
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
logger.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
lora_config.save_pretrained(cfg.output_dir)
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
@@ -407,9 +416,11 @@ def train(
|
||||
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
|
||||
)
|
||||
|
||||
logger.info("Starting trainer...")
|
||||
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
logger.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user