configure log level, add llama 7b config
This commit is contained in:
@@ -39,6 +39,7 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))
|
||||
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
|
||||
|
||||
|
||||
@@ -353,9 +354,15 @@ def train(
|
||||
else:
|
||||
datasets = []
|
||||
for d in cfg.datasets:
|
||||
ds: IterableDataset = load_dataset(
|
||||
"json", data_files=d.path, streaming=True, split=None
|
||||
)
|
||||
if Path(d.path).exists():
|
||||
ds: IterableDataset = load_dataset(
|
||||
"json", data_files=d.path, streaming=True, split=None
|
||||
)
|
||||
# elif d.name and d.path:
|
||||
# # TODO load from huggingface hub, but it only seems to support arrow or parquet atm
|
||||
# ds = load_dataset(d.path, split=None, data_files=d.name)
|
||||
else:
|
||||
raise Exception("unhandled dataset load")
|
||||
|
||||
if d.type == "alpaca":
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
|
||||
Reference in New Issue
Block a user