Apply isort then black

This commit is contained in:
NanoCode012
2023-05-29 18:48:58 +09:00
parent 96e8378692
commit 37293dce07
15 changed files with 158 additions and 97 deletions

View File

@@ -2,23 +2,20 @@
import os
import sys
from typing import Optional, Union
from pathlib import Path
from typing import Optional, Union
import fire
from axolotl.convert import (
FileReader,
StdoutWriter,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")

View File

@@ -7,20 +7,20 @@ import random
import signal
import sys
from pathlib import Path
from typing import Optional, List, Dict, Any, Union
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
if cfg.local_rank == 0:
signal.signal(
signal.SIGINT,
lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
)
logging.info("Starting trainer...")
@@ -255,7 +258,8 @@ def train(
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(