Lint and format

Author: NanoCode012
Date: 2023-05-29 03:45:42 +09:00
parent a98deb31a6
commit 392dfd9b07
9 changed files with 82 additions and 58 deletions


@@ -1,3 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-src_dir = os.path.join(project_root, "src")
-sys.path.insert(0, src_dir)
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
+sys.path.insert(0, src_dir)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
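
A note on the moved block: pylint's wrong-import-position (C0413) flags imports that appear after executable statements, which is likely why the sys.path shim ends up below the final import group here. The shim itself, for reference (the directory layout of scripts/ next to src/ is inferred from the paths in the diff):

import os
import sys

# Prepend the repo's src/ tree so `import axolotl` resolves from a source
# checkout without `pip install -e .`.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(project_root, "src"))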
@@ -37,7 +40,7 @@ def choose_device(cfg):
try:
if torch.backends.mps.is_available():
return "mps"
-except:
+except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
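
For context on the broad-exception-caught fix: a bare `except:` also traps KeyboardInterrupt and SystemExit, so pylint steers toward `except Exception`. A minimal sketch of the device probe, loosely based on this hunk (the CUDA branch is an assumption, not shown above):

import torch

def get_device() -> str:
    """Probe for an accelerator, falling back to CPU if the probe fails."""
    try:
        if torch.cuda.is_available():  # assumed branch; not shown in the hunk
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except Exception:  # pylint: disable=broad-exception-caught
        pass  # a bare `except:` here would also swallow KeyboardInterrupt
    return "cpu"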
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
model.eval()
with torch.no_grad():
-# gc = GenerationConfig() # TODO swap out and use this
+# gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
do_sample=True,
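
The fixme disable keeps the TODO visible rather than resolving it; the TODO itself points at transformers' GenerationConfig. A rough sketch of the swap it suggests (sampling values are illustrative, not taken from the diff):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True,
    max_new_tokens=128,
    temperature=0.7,
)

# Inside do_inference, the call would then become:
# generated = model.generate(
#     inputs=batch["input_ids"].to(cfg.device),
#     generation_config=generation_config,
# )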
@@ -130,12 +133,12 @@ def train(
config = choose_config(config)
# load the config from the yaml file
with open(config, "r") as f:
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
-for k in kwargs:
+for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False:
# handle booleans
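
Two pylint rules apparently drive this hunk: unspecified-encoding (W1514) requires an explicit encoding= on open(), and consider-using-dict-items (C0206) fires because the loop body indexes kwargs[k]. A condensed sketch of the resulting pattern (the path and override values are placeholders):

import yaml

with open("config.yml", encoding="utf-8") as file:  # placeholder path
    cfg = yaml.load(file, Loader=yaml.Loader)

overrides = {"learning_rate": "3e-4"}  # stand-in for the CLI kwargs
for key, _ in overrides.items():
    if key in cfg or cfg.get("strict") is False:
        cfg[key] = overrides[key]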
@@ -167,13 +170,11 @@ def train(
# load the tokenizer first
logging.info("loading tokenizer...")
-tokenizer = load_tokenizer(
-cfg.base_model_config,
-cfg.tokenizer_type,
-cfg
-)
+tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
-if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
+if check_not_in(
+["inference", "shard", "merge_lora"], kwargs
+): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
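
check_not_in is defined elsewhere in this file and isn't shown in the hunk; a minimal implementation inferred from the call site (true when none of the listed keys were passed):

from typing import Any, Dict, List, Union

def check_not_in(values: List[str], other: Union[Dict[str, Any], List[str]]) -> bool:
    """Return True when none of `values` appears in `other` (inferred behavior)."""
    return not any(value in other for value in values)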
@@ -262,10 +263,13 @@ def train(
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+# pylint: disable=fixme
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
model.save_pretrained(cfg.output_dir)
+# pylint: disable=fixme
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
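
The rank-0 guard in this hunk is the standard multi-process checkpointing pattern: every rank executes this code path, so only one process is allowed to write. A self-contained sketch (the function name and env-var fallback are illustrative, not from the diff):

import os

def save_on_rank_zero(model, output_dir: str) -> None:
    """Write the checkpoint from a single process to avoid concurrent writes."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # set by torchrun
    if local_rank == 0:
        model.save_pretrained(output_dir)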