Compare commits

...

1 Commit

Author      SHA1        Message                               Date
Wing Lian   e08df47584  wip load remote data from postgres    2024-02-12 09:55:24 -05:00
5 changed files with 54 additions and 12 deletions


@@ -10,9 +10,9 @@ strict: false
 max_steps: 200
 pretraining_dataset:
-  path: c4
+- path: c4
   name: en
   type: pretrain
 
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
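
After this change, pretraining_dataset is a list of mappings rather than a single mapping. As a minimal sketch (keys taken from the YAML above), the parsed config that the loader code further down consumes looks like:

# parsed form of the config above; prepare_dataset reads pretraining_dataset[0]
pretraining_dataset = [{"path": "c4", "name": "en", "type": "pretrain"}]
ds_cfg = pretraining_dataset[0]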


@@ -0,0 +1,37 @@
import os
from typing import Callable, Generator, Tuple

import psycopg
import psycopg.conninfo
from psycopg.rows import dict_row


def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
    pgsql_conn = os.environ.get("PGSQL_CONN", None)
    if not pgsql_conn:
        raise ValueError("missing PGSQL_CONN environment variable")
    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)

    def data_generator() -> Generator[Tuple, None, None]:
        # page through the table with keyset pagination on id_field, yielding
        # (key, example) pairs in the shape datasets' ExamplesIterable expects
        with psycopg.connect(**conn_dict, row_factory=dict_row) as conn:
            with conn.cursor() as cur:
                page_size = 10
                last_id = None
                while True:
                    where_clause = f" WHERE {id_field} > %s" if last_id is not None else ""
                    params = (last_id,) if last_id is not None else None
                    cur.execute(
                        f"SELECT * FROM {pgsql_table}{where_clause}"  # nosec
                        f" ORDER BY {id_field} ASC LIMIT {page_size}",
                        params,
                    )
                    rows = cur.fetchall()
                    if not rows:
                        break  # no rows left; stop paging
                    for row in rows:
                        yield row[id_field], dict(row)
                    last_id = rows[-1][id_field]

    return data_generator
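
A rough usage sketch of the loader above, assuming a hypothetical "items" table with an integer "id" column and an illustrative connection string (neither is part of the commit):

import os

os.environ["PGSQL_CONN"] = "postgresql://user:pass@localhost:5432/mydb"  # placeholder DSN
gen = pgsql(pgsql_table="items", id_field="id")
for key, example in gen():  # yields (id, row-dict) pairs, one per table row
    print(key, example)
    break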


@@ -1,6 +1,7 @@
 """Module containing data utilities"""
 import functools
 import hashlib
+import importlib
 import logging
 from collections import defaultdict
 from pathlib import Path
@@ -11,10 +12,12 @@ import yaml
 from datasets import (
     Dataset,
     DatasetDict,
+    IterableDataset,
     concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
+from datasets.iterable_dataset import ExamplesIterable
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HFValidationError
 from torch.utils.data import RandomSampler
@@ -64,6 +67,27 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
     return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
 
 
+def get_streaming_dataset(ds_cfg):
+    path = ds_cfg["path"]
+    func = None
+    try:
+        # a dotted path like "utils.data.pgsql.pgsql" names a loader function
+        # inside the axolotl package: the module path first, the attribute last
+        load_fn = path.split(".")[-1]
+        module_name = ".".join(path.split(".")[:-1])
+        mod = importlib.import_module(f".{module_name}", "axolotl")
+        func = getattr(mod, load_fn)
+    except Exception:
+        # not a resolvable loader path; fall back to a hub dataset below
+        pass
+
+    if func:
+        data_producer = func(**ds_cfg)
+        return IterableDataset(ExamplesIterable(data_producer, {}))
+    split = ds_cfg.get("split") or "train"
+    return load_dataset(path, streaming=True, split=split, name=ds_cfg.get("name"))
+
+
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
@@ -80,14 +104,6 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
-        path = cfg.pretraining_dataset
-        name = None
-        if isinstance(cfg.pretraining_dataset, list) and isinstance(
-            cfg.pretraining_dataset[0], dict
-        ):
-            path = cfg.pretraining_dataset[0]["path"]
-            name = cfg.pretraining_dataset[0]["name"]
-
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
@@ -97,7 +113,7 @@ def prepare_dataset(cfg, tokenizer):
         )
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            get_streaming_dataset(cfg.pretraining_dataset[0]),
             tokenizer,
             cfg,
             ds_wrapper_partial,
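
For context, get_streaming_dataset dispatches on the path value: a dotted path that resolves inside the axolotl package yields a custom generator wrapped in an IterableDataset, while anything else falls through to load_dataset with streaming=True. A hedged sketch of both branches; the dotted module path here is an assumption about where a pgsql loader might live, not something this commit pins down:

# custom-loader branch; "utils.data.pgsql.pgsql" is a hypothetical location
pg_stream = get_streaming_dataset(
    {"path": "utils.data.pgsql.pgsql", "pgsql_table": "items", "id_field": "id"}
)

# hub branch, mirroring the example config above
hub_stream = get_streaming_dataset({"path": "c4", "name": "en", "split": "train"})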