diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml
index 874bbcf52..2e31c9e2e 100644
--- a/examples/tiny-llama/pretrain.yml
+++ b/examples/tiny-llama/pretrain.yml
@@ -10,9 +10,9 @@ strict: false
 
 max_steps: 200
 pretraining_dataset:
-  path: c4
-  name: en
-  type: pretrain
+  - path: c4
+    name: en
+    type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
diff --git a/src/axolotl/plugins/oaaic/__init__.py b/src/axolotl/plugins/oaaic/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/plugins/oaaic/data/__init__.py b/src/axolotl/plugins/oaaic/data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/plugins/oaaic/data/streaming_sql.py b/src/axolotl/plugins/oaaic/data/streaming_sql.py
new file mode 100644
index 000000000..353c2be2d
--- /dev/null
+++ b/src/axolotl/plugins/oaaic/data/streaming_sql.py
@@ -0,0 +1,35 @@
+"""Streaming dataset loader backed by a PostgreSQL table."""
+import os
+from typing import Callable, Generator, Tuple
+
+import psycopg
+import psycopg.conninfo
+from psycopg.rows import dict_row
+
+
+def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
+    pgsql_conn = os.environ.get("PGSQL_CONN", None)
+    if not pgsql_conn:
+        raise ValueError("missing PGSQL_CONN environment variable")
+    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
+
+    def data_generator() -> Generator[Tuple, None, None]:
+        with psycopg.connect(**conn_dict) as conn:
+            # dict_row lets rows be addressed by column name below
+            with conn.cursor(row_factory=dict_row) as cur:
+                page_size = 10
+                last_id = None
+                while True:
+                    # keyset pagination; identifiers come from trusted config
+                    where_clause = f" WHERE {id_field} > {last_id}" if last_id else ""
+                    cur.execute(
+                        f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"  # nosec
+                    )
+                    rows = cur.fetchall()
+                    if not rows:
+                        break
+                    for row in rows:
+                        last_id = row[id_field]
+                        yield last_id, dict(row)
+
+    return data_generator
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 66a9b0a71..f78e682a1 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,6 +1,7 @@
 """Module containing data utilities"""
 import functools
 import hashlib
+import importlib
 import logging
 from collections import defaultdict
 from pathlib import Path
@@ -11,10 +12,12 @@ import yaml
 from datasets import (
     Dataset,
     DatasetDict,
+    IterableDataset,
     concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
+from datasets.iterable_dataset import ExamplesIterable
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HFValidationError
 from torch.utils.data import RandomSampler
@@ -64,6 +67,29 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
     return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
 
 
+def get_streaming_dataset(ds_cfg):
+    path = ds_cfg["path"]
+    func = None
+    try:
+        # treat the path as a dotted module path relative to the axolotl
+        # package, with the final component naming the loader function
+        load_fn = path.split(".")[-1]
+        module_name = ".".join(path.split(".")[:-1])
+        mod = importlib.import_module(f".{module_name}", "axolotl")
+        func = getattr(mod, load_fn)
+    except (ImportError, AttributeError):
+        pass
+
+    if func:
+        # the loader returns a generator fn yielding (key, example) tuples
+        data_producer = func(**ds_cfg)
+        return IterableDataset(ExamplesIterable(data_producer, {}))
+
+    # otherwise fall back to a plain streaming Hugging Face dataset
+    split = ds_cfg.get("split") or "train"
+    return load_dataset(path, streaming=True, split=split, name=ds_cfg.get("name"))
+
+
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
@@ -80,14 +106,6 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
-        path = cfg.pretraining_dataset
-        name = None
-        if isinstance(cfg.pretraining_dataset, list) and isinstance(
-            cfg.pretraining_dataset[0], dict
-        ):
-            path = cfg.pretraining_dataset[0]["path"]
-            name = cfg.pretraining_dataset[0]["name"]
-
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
@@ -97,7 +115,7 @@ def prepare_dataset(cfg, tokenizer):
         )
 
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            get_streaming_dataset(cfg.pretraining_dataset[0]),
             tokenizer,
             cfg,
             ds_wrapper_partial,
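
For reference, here is an untested sketch of a config that would exercise the new loader, following the dotted-path resolution in `get_streaming_dataset`: the module is imported relative to the `axolotl` package, the last path component names the function, and all other keys are forwarded to `pgsql()` via `**ds_cfg`. The table name and connection string below are placeholders, not part of this change:

```yaml
# hypothetical wiring for the streaming SQL loader
# requires: export PGSQL_CONN="host=localhost dbname=corpus user=axolotl"
pretraining_dataset:
  - path: plugins.oaaic.data.streaming_sql.pgsql  # resolved as axolotl.plugins.oaaic.data.streaming_sql
    pgsql_table: pretrain_corpus  # placeholder table name, forwarded to pgsql()
    id_field: id
    type: pretrain
```

The loader pages through the table with keyset pagination (`WHERE {id_field} > last_id ... LIMIT page_size`), so rows stream in key order with bounded memory rather than materializing one unbounded `SELECT *`.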