Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
e08df47584 wip load remote data from postgres 2024-02-12 09:55:24 -05:00
9 changed files with 58 additions and 158 deletions

View File

@@ -10,9 +10,9 @@ strict: false
max_steps: 200 max_steps: 200
pretraining_dataset: pretraining_dataset:
path: c4 - path: c4
name: en name: en
type: pretrain type: pretrain
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./model-out output_dir: ./model-out

View File

@@ -49,7 +49,7 @@ from axolotl.utils.collators import (
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler, get_cosine_schedule_with_quadratic_warmup,
) )
try: try:
@@ -129,19 +129,7 @@ class AxolotlTrainingArguments(TrainingArguments):
) )
relora_anneal_steps: Optional[int] = field( relora_anneal_steps: Optional[int] = field(
default=None, default=None,
metadata={"help": "how many anneal steps to take before reset for ReLoRA"}, metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
jagged_restart_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for jagged restarts"},
)
jagged_restarts_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
)
jagged_restarts_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
) )
bench_split: Optional[str] = field( bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"} default="eval", metadata={"help": "The benchmark split to run on"}
@@ -238,7 +226,7 @@ class AxolotlTrainer(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio, min_lr_ratio=self.args.cosine_min_lr_ratio,
) )
else: else:
super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
else: else:
if use_cosine_quadratic: if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -246,21 +234,6 @@ class AxolotlTrainer(Trainer):
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
if self.args.jagged_restart_steps:
warmup_steps = (
self.args.jagged_restarts_warmup_steps or 10
)
anneal_steps = (
self.args.jagged_restarts_anneal_steps or 1
)
self.lr_scheduler = JaggedLRRestartScheduler(
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,
warmup_steps,
anneal_steps,
)
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
@@ -900,8 +873,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
) )
if self.cfg.save_only_model:
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
training_arguments_kwargs["lr_scheduler_type"] = ( training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler self.cfg.lr_scheduler
if self.cfg.lr_scheduler if self.cfg.lr_scheduler

View File

View File

@@ -0,0 +1,28 @@
import os
from typing import Callable, Generator, Tuple
import psycopg
import psycopg.conninfo
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
pgsql_conn = os.environ.get("PGSQL_CONN", None)
if not pgsql_conn:
raise ValueError("missing PGSQL_CONN environment variable")
conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
def data_generator() -> Generator[Tuple, None, None]:
with psycopg.connect(**conn_dict) as conn:
with conn.cursor() as cur:
page_size = 10
last_id = None
while True:
if last_id:
where_clause = f" WHERE {id_field} > {last_id}"
cur.execute(
f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
)
for row in cur.fetchall():
yield row[id_field], dict(row)
return data_generator

View File

@@ -1,67 +0,0 @@
from typing import Optional, Dict, Any
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
def __init__(self, tokenizer, chat_template=None, max_length=2048):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False):
return self.tokenizer.apply_chat_template(
conversation, truncation=True, max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids)
}
return tokenized_prompt
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {"human": "user", "user": "user", "assistant": "assistant", "gpt": "assistant"}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_templates(ds_cfg["conversation"]),
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy

View File

@@ -62,7 +62,7 @@ class EvalFirstStepCallback(
): ):
if ( if (
args.evaluation_strategy == IntervalStrategy.STEPS args.evaluation_strategy == IntervalStrategy.STEPS
and (args.eval_steps < 1.0 or args.eval_steps > 1) and args.eval_steps < 1.0
and state.global_step == 1 and state.global_step == 1
): ):
control.should_evaluate = True control.should_evaluate = True

View File

@@ -1,6 +1,7 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools import functools
import hashlib import hashlib
import importlib
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@@ -11,10 +12,12 @@ import yaml
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
IterableDataset,
concatenate_datasets, concatenate_datasets,
load_dataset, load_dataset,
load_from_disk, load_from_disk,
) )
from datasets.iterable_dataset import ExamplesIterable
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError from huggingface_hub.utils import HFValidationError
from torch.utils.data import RandomSampler from torch.utils.data import RandomSampler
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def get_streaming_dataset(ds_cfg):
path = ds_cfg["path"]
func = None
try:
load_fn = path.split(".")[-1]
module_name = ".".join(load_fn.split(".")[:-1])
mod = importlib.import_module(f".{module_name}", "axolotl")
func = getattr(mod, load_fn)
except Exception:
pass
if func:
data_producer = func(**ds_cfg)
return IterableDataset(ExamplesIterable(data_producer, {}))
else:
split = ds_cfg["split"] or "train"
return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
) )
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name), get_streaming_dataset(cfg.pretraining_dataset[0]),
tokenizer, tokenizer,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,

View File

@@ -1,7 +1,6 @@
"""Module for custom LRScheduler class""" """Module for custom LRScheduler class"""
import math import math
from functools import partial from functools import partial
from typing import Sequence
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -141,48 +140,3 @@ def get_cosine_schedule_with_min_lr(
min_lr_ratio=min_lr_ratio, min_lr_ratio=min_lr_ratio,
) )
return LambdaLR(optimizer, lr_lambda) return LambdaLR(optimizer, lr_lambda)
class JaggedLRRestartScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
jagged_restarts_steps: int,
jagged_restarts_warmup_steps: int,
jagged_restarts_anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.restarts_steps = jagged_restarts_steps
self.warmup_steps = jagged_restarts_warmup_steps
self.anneal_steps = jagged_restarts_anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.restarts_steps:
scale = 1
else:
per_relora_progress = step % self.restarts_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.restarts_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.restarts_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale