Compare commits

1 commit

optimizers ... streaming-

| Author | SHA1 | Date |
|---|---|---|
| | e08df47584 | |
```diff
@@ -10,9 +10,9 @@ strict: false
 max_steps: 200
 pretraining_dataset:
-  path: c4
+- path: c4
   name: en
   type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out

```
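The one-character config change above is load-bearing: `pretraining_dataset` goes from a single mapping to a list of mappings, which is the shape the Python changes below address as `cfg.pretraining_dataset[0]`. A minimal sketch of the two shapes, using only the values from this diff:

```python
# Old shape: pretraining_dataset is a single mapping.
old_cfg = {"pretraining_dataset": {"path": "c4", "name": "en", "type": "pretrain"}}

# New shape: a list of mappings, so the first entry is addressed by index,
# exactly as the updated prepare_dataset() does with cfg.pretraining_dataset[0].
new_cfg = {"pretraining_dataset": [{"path": "c4", "name": "en", "type": "pretrain"}]}

assert new_cfg["pretraining_dataset"][0]["path"] == "c4"
```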
| Lines | File | |
|---|---|---|
| 0 | src/axolotl/plugins/oaaic/__init__.py | Normal file |
| 0 | src/axolotl/plugins/oaaic/data/__init__.py | Normal file |
| 28 | src/axolotl/plugins/oaaic/data/streaming_sql.py | Normal file |
```diff
@@ -0,0 +1,28 @@
+import os
+from typing import Callable, Generator, Tuple
+
+import psycopg
+import psycopg.conninfo
+
+
+def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
+    pgsql_conn = os.environ.get("PGSQL_CONN", None)
+    if not pgsql_conn:
+        raise ValueError("missing PGSQL_CONN environment variable")
+    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
+
+    def data_generator() -> Generator[Tuple, None, None]:
+        with psycopg.connect(**conn_dict) as conn:
+            with conn.cursor() as cur:
+                page_size = 10
+                last_id = None
+                while True:
+                    if last_id:
+                        where_clause = f" WHERE {id_field} > {last_id}"
+                    cur.execute(
+                        f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
+                    )
+                    for row in cur.fetchall():
+                        yield row[id_field], dict(row)
+
+    return data_generator
```
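Three things in this generator are easy to trip over: `where_clause` is unbound on the first loop iteration (a `NameError` before any row is yielded), `last_id` is never advanced so the loop would re-read the first page forever, and a default psycopg cursor returns tuples, which makes `row[id_field]` fail for a string key. A corrected sketch of the same keyset-pagination idea, assuming psycopg 3 and its `dict_row` row factory; the fixes are mine, not part of this commit:

```python
import os
from typing import Callable, Generator, Tuple

import psycopg
import psycopg.conninfo
from psycopg.rows import dict_row


def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
    pgsql_conn = os.environ.get("PGSQL_CONN", None)
    if not pgsql_conn:
        raise ValueError("missing PGSQL_CONN environment variable")
    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)

    def data_generator() -> Generator[Tuple, None, None]:
        with psycopg.connect(**conn_dict) as conn:
            # dict_row makes rows indexable by column name, so row[id_field] works.
            with conn.cursor(row_factory=dict_row) as cur:
                page_size = 10
                last_id = None
                while True:
                    # Empty on the first page, keyset filter on every later page.
                    # Table and id field come from trusted config, hence the f-string.
                    where_clause = ""
                    if last_id is not None:
                        where_clause = f" WHERE {id_field} > {last_id}"
                    cur.execute(
                        f"SELECT * FROM {pgsql_table}{where_clause} "
                        f"ORDER BY {id_field} ASC LIMIT {page_size}"
                    )
                    rows = cur.fetchall()
                    if not rows:
                        break  # table exhausted: stop instead of spinning forever
                    for row in rows:
                        yield row[id_field], dict(row)
                    # Advance the keyset so the next SELECT starts after this page.
                    last_id = rows[-1][id_field]

    return data_generator
```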
```diff
@@ -1,6 +1,7 @@
 """Module containing data utilities"""
 import functools
 import hashlib
+import importlib
 import logging
 from collections import defaultdict
 from pathlib import Path
```
```diff
@@ -11,10 +12,12 @@ import yaml
 from datasets import (
     Dataset,
     DatasetDict,
+    IterableDataset,
     concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
+from datasets.iterable_dataset import ExamplesIterable
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HFValidationError
 from torch.utils.data import RandomSampler
```
```diff
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
     return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec


+def get_streaming_dataset(ds_cfg):
+    path = ds_cfg["path"]
+    func = None
+    try:
+        load_fn = path.split(".")[-1]
+        module_name = ".".join(load_fn.split(".")[:-1])
+        mod = importlib.import_module(f".{module_name}", "axolotl")
+        func = getattr(mod, load_fn)
+    except Exception:
+        pass
+
+    if func:
+        data_producer = func(**ds_cfg)
+        return IterableDataset(ExamplesIterable(data_producer, {}))
+    else:
+        split = ds_cfg["split"] or "train"
+        return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
+
+
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
```
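One subtlety in `get_streaming_dataset` as committed: `load_fn` is already the last dotted component of `path`, so `load_fn.split(".")[:-1]` is always empty and `module_name` is always `""`, meaning the relative import resolves to the `axolotl` package itself and only top-level attributes can be found. A sketch of what the resolution presumably intends, deriving the module from `path` instead; the helper name and the example dotted path are illustrative assumptions, not confirmed by this commit:

```python
import importlib
from typing import Callable, Optional


def resolve_loader(path: str) -> Optional[Callable]:
    """Resolve a dotted path such as "plugins.oaaic.data.streaming_sql.pgsql"
    to a callable inside the axolotl package, or None if it does not resolve."""
    load_fn = path.split(".")[-1]                 # e.g. "pgsql"
    module_name = ".".join(path.split(".")[:-1])  # e.g. "plugins.oaaic.data.streaming_sql"
    try:
        mod = importlib.import_module(f".{module_name}", "axolotl")
        return getattr(mod, load_fn)
    except (ImportError, AttributeError):
        # Plain dataset names like "c4" land here, triggering the hub fallback.
        return None
```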
```diff
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
-        path = cfg.pretraining_dataset
-        name = None
-        if isinstance(cfg.pretraining_dataset, list) and isinstance(
-            cfg.pretraining_dataset[0], dict
-        ):
-            path = cfg.pretraining_dataset[0]["path"]
-            name = cfg.pretraining_dataset[0]["name"]
-
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
```
```diff
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
         )

         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            get_streaming_dataset(cfg.pretraining_dataset[0]),
             tokenizer,
             cfg,
             ds_wrapper_partial,
```
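With that change, `prepare_dataset` no longer hard-codes the hub loader: `get_streaming_dataset(cfg.pretraining_dataset[0])` either wraps a resolved generator factory in an `IterableDataset` or falls back to streaming from the hub. For the c4 entry in this diff's config, the fallback branch amounts to the call below:

```python
from datasets import load_dataset

# ds_cfg = {"path": "c4", "name": "en", ...} from the config change above;
# "c4" is not a resolvable dotted function path, so the hub fallback fires.
ds = load_dataset("c4", streaming=True, split="train", name="en")
print(next(iter(ds)))  # first example streamed, without downloading the corpus
```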