configure log level, add llama 7b config

This commit is contained in:
Wing Lian
2023-04-15 14:24:37 -04:00
parent 05fffb53b4
commit d33a975747

View File

@@ -39,6 +39,7 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
@@ -353,9 +354,15 @@ def train(
else:
datasets = []
for d in cfg.datasets:
ds: IterableDataset = load_dataset(
"json", data_files=d.path, streaming=True, split=None
)
if Path(d.path).exists():
ds: IterableDataset = load_dataset(
"json", data_files=d.path, streaming=True, split=None
)
# elif d.name and d.path:
# # TODO: load from the Hugging Face Hub, but the hub only seems to support Arrow or Parquet at the moment
# ds = load_dataset(d.path, split=None, data_files=d.name)
else:
raise Exception("unhandled dataset load")
if d.type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(