fix logging

This commit is contained in:
Wing Lian
2023-04-15 23:12:48 -04:00
parent 23938015c8
commit a4593832a9

View File

@@ -38,8 +38,7 @@ from axolotl.prompt_tokenizers import (
) )
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
logger = logging.getLogger(__name__) logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "data/last_run" DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
@@ -171,8 +170,8 @@ def check_dataset_labels(dataset, tokenizer):
) )
colored_tokens.append(colored_token) colored_tokens.append(colored_token)
logger.info(" ".join(colored_tokens)) logging.info(" ".join(colored_tokens))
logger.info("\n\n\n") logging.info("\n\n\n")
def do_inference(cfg, model, tokenizer): def do_inference(cfg, model, tokenizer):
@@ -349,9 +348,9 @@ def train(
return return
if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")): if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")):
logger.info("Loading prepared dataset from disk...") logging.info("Loading prepared dataset from disk...")
dataset = load_from_disk(cfg.dataset_prepared_path) dataset = load_from_disk(cfg.dataset_prepared_path)
logger.info("Prepared dataset loaded from disk...") logging.info("Prepared dataset loaded from disk...")
else: else:
datasets = [] datasets = []
for d in cfg.datasets: for d in cfg.datasets:
@@ -391,14 +390,14 @@ def train(
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42) ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
if cfg.local_rank == 0: if cfg.local_rank == 0:
logger.info("Saving prepared dataset to disk...") logging.info("Saving prepared dataset to disk...")
if cfg.dataset_prepared_path: if cfg.dataset_prepared_path:
dataset.save_to_disk(cfg.dataset_prepared_path) dataset.save_to_disk(cfg.dataset_prepared_path)
else: else:
dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH) dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH)
if prepare_ds_only: if prepare_ds_only:
logger.info("Finished preparing dataset. Exiting...") logging.info("Finished preparing dataset. Exiting...")
return return
train_dataset = dataset["train"] train_dataset = dataset["train"]
@@ -415,11 +414,11 @@ def train(
model.config.use_cache = False model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32": if torch.__version__ >= "2" and sys.platform != "win32":
logger.info("Compiling torch model") logging.info("Compiling torch model")
model = torch.compile(model) model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
logger.info(f"Pre-saving adapter config to {cfg.output_dir}") logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
lora_config.save_pretrained(cfg.output_dir) lora_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
@@ -428,11 +427,11 @@ def train(
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)), lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
) )
logger.info("Starting trainer...") logging.info("Starting trainer...")
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
logger.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)