Compare commits
1 Commits
fix/diffus
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08df47584 |
@@ -10,9 +10,9 @@ strict: false
|
||||
|
||||
max_steps: 200
|
||||
pretraining_dataset:
|
||||
path: c4
|
||||
name: en
|
||||
type: pretrain
|
||||
- path: c4
|
||||
name: en
|
||||
type: pretrain
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./model-out
|
||||
|
||||
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
from typing import Callable, Generator, Tuple

import psycopg
import psycopg.conninfo
import psycopg.rows
import psycopg.sql
|
||||
|
||||
|
||||
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
    """Build a generator factory that streams rows from a PostgreSQL table.

    Connection info is read from the ``PGSQL_CONN`` environment variable
    (a libpq conninfo string). Rows are fetched in keyset-paginated pages
    ordered by ``id_field``, so memory stays bounded regardless of table
    size.

    Args:
        pgsql_table: Name of the table to stream from.
        id_field: Monotonically increasing key column used for pagination.
        **kwargs: Ignored; accepted so a whole dataset-config dict can be
            splatted into this loader unchanged.

    Returns:
        A zero-argument callable that yields ``(row_id, row_dict)`` tuples.

    Raises:
        ValueError: If ``PGSQL_CONN`` is not set.
    """
    pgsql_conn = os.environ.get("PGSQL_CONN", None)
    if not pgsql_conn:
        raise ValueError("missing PGSQL_CONN environment variable")
    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)

    def data_generator() -> Generator[Tuple, None, None]:
        # dict_row so row[id_field] works: psycopg's default tuple rows
        # raise on a string index.
        with psycopg.connect(**conn_dict, row_factory=psycopg.rows.dict_row) as conn:
            with conn.cursor() as cur:
                page_size = 10
                last_id = None
                while True:
                    # Compose identifiers via psycopg.sql instead of f-string
                    # interpolation (SQL injection via table/field names), and
                    # bind values as parameters.
                    query = psycopg.sql.SQL("SELECT * FROM {table}").format(
                        table=psycopg.sql.Identifier(pgsql_table)
                    )
                    params = []
                    if last_id is not None:
                        # Original referenced where_clause without initializing
                        # it on the first iteration -> NameError; build the
                        # clause only when there is a cursor position.
                        query += psycopg.sql.SQL(" WHERE {field} > %s").format(
                            field=psycopg.sql.Identifier(id_field)
                        )
                        params.append(last_id)
                    query += psycopg.sql.SQL(" ORDER BY {field} ASC LIMIT %s").format(
                        field=psycopg.sql.Identifier(id_field)
                    )
                    params.append(page_size)
                    cur.execute(query, params)
                    rows = cur.fetchall()
                    if not rows:
                        # Exhausted: stop instead of re-fetching forever (the
                        # original never advanced last_id nor broke the loop).
                        break
                    for row in rows:
                        last_id = row[id_field]
                        yield last_id, dict(row)

    return data_generator
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@@ -11,10 +12,12 @@ import yaml
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from datasets.iterable_dataset import ExamplesIterable
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HFValidationError
|
||||
from torch.utils.data import RandomSampler
|
||||
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||
|
||||
|
||||
def get_streaming_dataset(ds_cfg):
    """Resolve a streaming dataset from a dataset-config dict.

    If ``ds_cfg["path"]`` is a dotted path to a loader function inside the
    ``axolotl`` package (e.g. ``plugins.oaaic.data.streaming_sql.pgsql``),
    that loader is called with the config and wrapped as an
    ``IterableDataset``. Otherwise the path is treated as a Hugging Face
    Hub dataset id and loaded with ``streaming=True``.

    Args:
        ds_cfg: Dataset config; must contain ``path``, may contain
            ``split`` and ``name``.

    Returns:
        A streaming ``IterableDataset``.
    """
    path = ds_cfg["path"]
    func = None
    try:
        # Split the loader name off the *full* path. The original split the
        # already-isolated last segment again, so module_name was always ""
        # and the custom-loader branch could never trigger.
        module_name = ".".join(path.split(".")[:-1])
        load_fn = path.split(".")[-1]
        if module_name:
            mod = importlib.import_module(f".{module_name}", "axolotl")
            func = getattr(mod, load_fn, None)
    except (ImportError, AttributeError):
        # Not an importable loader path; fall through to the Hub loader.
        func = None

    if func:
        data_producer = func(**ds_cfg)
        return IterableDataset(ExamplesIterable(data_producer, {}))
    # .get: split/name are optional in the config; the original indexed them
    # directly and raised KeyError when absent.
    split = ds_cfg.get("split") or "train"
    return load_dataset(path, streaming=True, split=split, name=ds_cfg.get("name"))
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
path = cfg.pretraining_dataset
|
||||
name = None
|
||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||
cfg.pretraining_dataset[0], dict
|
||||
):
|
||||
path = cfg.pretraining_dataset[0]["path"]
|
||||
name = cfg.pretraining_dataset[0]["name"]
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
)
|
||||
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
load_dataset(path, streaming=True, split="train", name=name),
|
||||
get_streaming_dataset(cfg.pretraining_dataset[0]),
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
|
||||
Reference in New Issue
Block a user