Compare commits

...

1 commit

Author     SHA1        Message                              Date
Wing Lian  e08df47584  wip load remote data from postgres   2024-02-12 09:55:24 -05:00

5 changed files with 54 additions and 12 deletions

View File

@@ -10,9 +10,9 @@ strict: false
 max_steps: 200
 pretraining_dataset:
-  path: c4
-  name: en
-  type: pretrain
+  - path: c4
+    name: en
+    type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
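
With `pretraining_dataset` now a list of mappings rather than a single mapping, the code in `data.py` further down can index it as `cfg.pretraining_dataset[0]` and hand the whole entry to a loader. A sketch of the deserialized shape, for illustration:

    # what the new list-style YAML deserializes to
    pretraining_dataset = [
        {"path": "c4", "name": "en", "type": "pretrain"},
    ]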

View File

View File

@@ -0,0 +1,36 @@
+import os
+from typing import Callable, Generator, Tuple
+
+import psycopg
+import psycopg.conninfo
+from psycopg.rows import dict_row
+
+
+def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
+    pgsql_conn = os.environ.get("PGSQL_CONN", None)
+    if not pgsql_conn:
+        raise ValueError("missing PGSQL_CONN environment variable")
+    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
+
+    def data_generator() -> Generator[Tuple, None, None]:
+        # page through the table in id order so the full result set
+        # never has to be held in memory
+        with psycopg.connect(**conn_dict, row_factory=dict_row) as conn:
+            with conn.cursor() as cur:
+                page_size = 10
+                last_id = None
+                while True:
+                    where_clause = ""
+                    if last_id is not None:
+                        where_clause = f" WHERE {id_field} > {last_id}"
+                    cur.execute(
+                        f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"  # nosec
+                    )
+                    rows = cur.fetchall()
+                    if not rows:
+                        break
+                    for row in rows:
+                        last_id = row[id_field]
+                        yield row[id_field], dict(row)
+
+    return data_generator
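
A minimal usage sketch for the loader above; the connection string and table name are hypothetical:

    import os

    # hypothetical DSN; any libpq-style conninfo string works
    os.environ["PGSQL_CONN"] = "postgresql://user:pass@localhost:5432/corpus"

    gen_fn = pgsql(pgsql_table="documents", id_field="id")  # hypothetical table
    for key, example in gen_fn():
        print(key, example)  # yields (id, row-as-dict), paging 10 rows at a time
        break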

View File

@@ -1,6 +1,7 @@
 """Module containing data utilities"""
 import functools
 import hashlib
+import importlib
 import logging
 from collections import defaultdict
 from pathlib import Path
@@ -11,10 +12,12 @@ import yaml
 from datasets import (
     Dataset,
     DatasetDict,
+    IterableDataset,
     concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
+from datasets.iterable_dataset import ExamplesIterable
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HFValidationError
 from torch.utils.data import RandomSampler
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
     return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec


+def get_streaming_dataset(ds_cfg):
+    path = ds_cfg["path"]
+    func = None
+    try:
+        # treat path as a dotted reference to a loader function
+        # inside the axolotl package
+        load_fn = path.split(".")[-1]
+        module_name = ".".join(path.split(".")[:-1])
+        mod = importlib.import_module(f".{module_name}", "axolotl")
+        func = getattr(mod, load_fn)
+    except (ImportError, AttributeError):
+        pass
+
+    if func:
+        data_producer = func(**ds_cfg)
+        return IterableDataset(ExamplesIterable(data_producer, {}))
+
+    split = ds_cfg.get("split") or "train"
+    return load_dataset(path, streaming=True, split=split, name=ds_cfg.get("name"))
+
+
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
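
`ExamplesIterable` here is internal `datasets` machinery: it wraps a function that yields `(key, example)` pairs, which is exactly what `pgsql()` returns. For illustration, the public counterpart is `IterableDataset.from_generator`, which takes a generator of plain example dicts instead (a sketch, not what this diff uses):

    from datasets import IterableDataset

    def rows():
        # stand-in for a database-backed producer; no explicit keys here,
        # unlike the (key, example) pairs ExamplesIterable consumes
        for i in range(3):
            yield {"id": i, "text": f"row {i}"}

    streaming_ds = IterableDataset.from_generator(rows)
    for example in streaming_ds:
        print(example)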
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
-        path = cfg.pretraining_dataset
-        name = None
-        if isinstance(cfg.pretraining_dataset, list) and isinstance(
-            cfg.pretraining_dataset[0], dict
-        ):
-            path = cfg.pretraining_dataset[0]["path"]
-            name = cfg.pretraining_dataset[0]["name"]
-
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
         )
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            get_streaming_dataset(cfg.pretraining_dataset[0]),
             tokenizer,
             cfg,
             ds_wrapper_partial,
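
Net effect: the `path` of the first pretraining dataset entry can now name either a hub dataset or a loader function inside the axolotl package. A sketch of both branches of `get_streaming_dataset` (the dotted path and table name are hypothetical, since the new module's final location isn't shown in this diff):

    # hub branch: the in-package import fails, so it falls back to load_dataset
    ds = get_streaming_dataset({"path": "c4", "name": "en", "split": "train"})

    # loader branch: resolves axolotl.utils.data.pgsql.pgsql (hypothetical path),
    # calls it with the config entry as kwargs, and wraps the generator it returns
    ds = get_streaming_dataset(
        {"path": "utils.data.pgsql.pgsql", "pgsql_table": "documents"}
    )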