Compare commits
1 Commits
20240216-u
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08df47584 |
@@ -10,9 +10,9 @@ strict: false
|
|||||||
|
|
||||||
max_steps: 200
|
max_steps: 200
|
||||||
pretraining_dataset:
|
pretraining_dataset:
|
||||||
path: c4
|
- path: c4
|
||||||
name: en
|
name: en
|
||||||
type: pretrain
|
type: pretrain
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from axolotl.utils.collators import (
|
|||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -129,19 +129,7 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
)
|
)
|
||||||
relora_anneal_steps: Optional[int] = field(
|
relora_anneal_steps: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
)
|
|
||||||
jagged_restart_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to reset for jagged restarts"},
|
|
||||||
)
|
|
||||||
jagged_restarts_warmup_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
|
|
||||||
)
|
|
||||||
jagged_restarts_anneal_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
|
|
||||||
)
|
)
|
||||||
bench_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
@@ -238,7 +226,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
@@ -246,21 +234,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
if use_cosine_min_lr:
|
if use_cosine_min_lr:
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
if self.args.jagged_restart_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.jagged_restarts_warmup_steps or 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.jagged_restarts_anneal_steps or 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = JaggedLRRestartScheduler(
|
|
||||||
optimizer,
|
|
||||||
self.lr_scheduler,
|
|
||||||
self.args.jagged_restart_steps,
|
|
||||||
warmup_steps,
|
|
||||||
anneal_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
@@ -900,8 +873,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
)
|
)
|
||||||
if self.cfg.save_only_model:
|
|
||||||
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
|
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||||
self.cfg.lr_scheduler
|
self.cfg.lr_scheduler
|
||||||
if self.cfg.lr_scheduler
|
if self.cfg.lr_scheduler
|
||||||
|
|||||||
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
0
src/axolotl/plugins/oaaic/data/__init__.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
28
src/axolotl/plugins/oaaic/data/streaming_sql.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
from typing import Callable, Generator, Tuple
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
import psycopg.conninfo
|
||||||
|
|
||||||
|
|
||||||
|
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
    """Build a generator function that streams rows from a PostgreSQL table.

    The connection string is read from the ``PGSQL_CONN`` environment
    variable. Rows are fetched with keyset pagination on ``id_field`` so the
    whole table is never held in memory.

    Args:
        pgsql_table: Table to read, optionally schema-qualified
            (``"schema.table"``).
        id_field: Column used both for ordering and as the pagination cursor.
        **kwargs: Accepted and ignored, so a whole dataset config mapping can
            be splatted into this function.

    Returns:
        A zero-argument callable; calling it returns a generator yielding
        ``(id, row_dict)`` tuples in ascending ``id_field`` order.

    Raises:
        ValueError: If ``PGSQL_CONN`` is unset/empty or no table is given.
    """
    pgsql_conn = os.environ.get("PGSQL_CONN", None)
    if not pgsql_conn:
        raise ValueError("missing PGSQL_CONN environment variable")
    if not pgsql_table:
        raise ValueError("missing pgsql_table for the pgsql streaming loader")
    conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
    page_size = 10

    def data_generator() -> Generator[Tuple, None, None]:
        # Imported here so the enclosing module only needs ``import psycopg``.
        from psycopg import sql
        from psycopg.rows import dict_row

        # Identifiers are quoted by psycopg instead of being pasted into the
        # query text; the cursor value travels as a bound parameter.
        table = sql.Identifier(*pgsql_table.split("."))
        key = sql.Identifier(id_field)
        limit = sql.Literal(page_size)
        first_page = sql.SQL("SELECT * FROM {} ORDER BY {} ASC LIMIT {}").format(
            table, key, limit
        )
        next_page = sql.SQL(
            "SELECT * FROM {} WHERE {} > %s ORDER BY {} ASC LIMIT {}"
        ).format(table, key, key, limit)

        with psycopg.connect(**conn_dict) as conn:
            # dict_row makes rows addressable by column name (row[id_field]).
            with conn.cursor(row_factory=dict_row) as cur:
                last_id = None
                while True:
                    # ``is None`` rather than truthiness: an id of 0 is a
                    # valid cursor position.
                    if last_id is None:
                        cur.execute(first_page)
                    else:
                        cur.execute(next_page, (last_id,))
                    rows = cur.fetchall()
                    if not rows:
                        # An empty page means the table is exhausted.
                        return
                    for row in rows:
                        last_id = row[id_field]
                        yield last_id, dict(row)

    return data_generator
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
|
||||||
from axolotl.prompters import Prompter
|
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplatePrompter(Prompter):
    """Prompter that renders conversations through a tokenizer's chat template."""

    def __init__(self, tokenizer, chat_template=None, max_length=2048):
        """Store the tokenizer, the optional template override and the length cap."""
        self.max_length = max_length
        self.chat_template = chat_template
        self.tokenizer = tokenizer

    def build_prompt(self, conversation, add_generation_prompt=False):
        """Apply the chat template to ``conversation``.

        Truncation is always enabled and bounded by ``self.max_length``.
        Returns whatever ``tokenizer.apply_chat_template`` returns.
        """
        template_options = {
            "truncation": True,
            "max_length": self.max_length,
            "add_generation_prompt": add_generation_prompt,
            "chat_template": self.chat_template,
        }
        return self.tokenizer.apply_chat_template(conversation, **template_options)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    def tokenize_prompt(self, prompt):
        """Tokenize one conversation into input ids, labels and attention mask.

        When ``train_on_inputs`` is false, the leading positions covered by
        the first turn (rendered alone with a generation prompt) are replaced
        by -100 in ``labels``.
        """
        thread = self.get_conversation_thread(prompt)
        # Rendering the first turn on its own gives the length of the prefix
        # that gets masked out of the labels.
        first_turn_ids = self.prompter.build_prompt(
            [thread[0]], add_generation_prompt=True
        )
        input_ids = self.prompter.build_prompt(thread)

        if self.train_on_inputs:
            labels = input_ids
        else:
            prefix_len = len(first_turn_ids)
            labels = [-100] * prefix_len + input_ids[prefix_len:]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }

    def get_conversation_thread(self, prompt):
        """Convert ``prompt["conversations"]`` into role/content message dicts."""
        conversations = prompt["conversations"]
        # remap roles - allow for assistant turn
        role_map = {
            "human": "user",
            "user": "user",
            "assistant": "assistant",
            "gpt": "assistant",
        }
        thread = []
        for turn in conversations:
            thread.append({"role": role_map[turn["from"]], "content": turn["value"]})
        return thread
|
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build the chat-template tokenizing strategy for a dataset.

    The template is looked up via ``ds_cfg["conversation"]``; despite the
    ``None`` default, ``ds_cfg`` must therefore be a mapping with that key.
    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are forwarded to the
    strategy.
    """
    template = chat_templates(ds_cfg["conversation"])
    prompter = ChatTemplatePrompter(tokenizer, template)
    return ChatTemplateStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
|
|
||||||
@@ -62,7 +62,7 @@ class EvalFirstStepCallback(
|
|||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||||
and (args.eval_steps < 1.0 or args.eval_steps > 1)
|
and args.eval_steps < 1.0
|
||||||
and state.global_step == 1
|
and state.global_step == 1
|
||||||
):
|
):
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module containing data utilities"""
|
"""Module containing data utilities"""
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -11,10 +12,12 @@ import yaml
|
|||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
|
IterableDataset,
|
||||||
concatenate_datasets,
|
concatenate_datasets,
|
||||||
load_dataset,
|
load_dataset,
|
||||||
load_from_disk,
|
load_from_disk,
|
||||||
)
|
)
|
||||||
|
from datasets.iterable_dataset import ExamplesIterable
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import HFValidationError
|
from huggingface_hub.utils import HFValidationError
|
||||||
from torch.utils.data import RandomSampler
|
from torch.utils.data import RandomSampler
|
||||||
@@ -64,6 +67,25 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|||||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||||
|
|
||||||
|
|
||||||
|
def get_streaming_dataset(ds_cfg):
    """Return a streaming dataset for one pretraining dataset config entry.

    ``ds_cfg["path"]`` is first interpreted as a dotted reference
    ``"<module>.<function>"`` resolved relative to the ``axolotl`` package.
    If such a callable exists, it is called with the whole config as keyword
    arguments and its result is wrapped in an ``IterableDataset``. Otherwise
    the path is handed to ``load_dataset`` in streaming mode.

    Args:
        ds_cfg: Mapping with a ``path`` key and optional ``split`` (defaults
            to ``"train"``) and ``name`` keys.

    Returns:
        A streaming dataset.
    """
    path = ds_cfg["path"]
    func = None

    # Split on the LAST dot only: everything before it is the module, the
    # final component is the loader function's name.
    module_name, _, load_fn = path.rpartition(".")
    if module_name and load_fn:
        try:
            mod = importlib.import_module(f".{module_name}", "axolotl")
            func = getattr(mod, load_fn, None)
        except (ImportError, ValueError, TypeError):
            # Not an importable plugin path (e.g. a hub id or file name that
            # merely contains a dot); fall through to load_dataset.
            func = None

    if callable(func):
        data_producer = func(**ds_cfg)
        return IterableDataset(ExamplesIterable(data_producer, {}))

    # .get() tolerates configs that omit the optional keys.
    split = ds_cfg.get("split") or "train"
    return load_dataset(path, streaming=True, split=split, name=ds_cfg.get("name"))
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
@@ -80,14 +102,6 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
|
||||||
name = None
|
|
||||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
|
||||||
cfg.pretraining_dataset[0], dict
|
|
||||||
):
|
|
||||||
path = cfg.pretraining_dataset[0]["path"]
|
|
||||||
name = cfg.pretraining_dataset[0]["name"]
|
|
||||||
|
|
||||||
ds_wrapper_partial = functools.partial(
|
ds_wrapper_partial = functools.partial(
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
cfg.pretraining_dataset[0],
|
cfg.pretraining_dataset[0],
|
||||||
@@ -97,7 +111,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = wrap_pretraining_dataset(
|
train_dataset = wrap_pretraining_dataset(
|
||||||
load_dataset(path, streaming=True, split="train", name=name),
|
get_streaming_dataset(cfg.pretraining_dataset[0]),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_partial,
|
ds_wrapper_partial,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Module for custom LRScheduler class"""
|
"""Module for custom LRScheduler class"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Sequence
|
|
||||||
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
@@ -141,48 +140,3 @@ def get_cosine_schedule_with_min_lr(
|
|||||||
min_lr_ratio=min_lr_ratio,
|
min_lr_ratio=min_lr_ratio,
|
||||||
)
|
)
|
||||||
return LambdaLR(optimizer, lr_lambda)
|
return LambdaLR(optimizer, lr_lambda)
|
||||||
|
|
||||||
|
|
||||||
class JaggedLRRestartScheduler(LRScheduler):
    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        jagged_restarts_steps: int,
        jagged_restarts_warmup_steps: int,
        jagged_restarts_anneal_steps: int = 1,
        min_lr_scale: float = 0.001,
    ) -> None:
        """Wrap ``inner_schedule`` with periodic restart scaling.

        Args:
            optimizer: Optimizer whose learning rates are scheduled.
            inner_schedule: Scheduler providing the unscaled learning rates.
            jagged_restarts_steps: Length of one restart cycle, in steps.
            jagged_restarts_warmup_steps: Steps of linear ramp-up at the
                start of every cycle after the first.
            jagged_restarts_anneal_steps: Steps of linear ramp-down at the
                end of every cycle after the first.
            min_lr_scale: Scale factor applied at the bottom of a ramp.
        """
        # These must be set before super().__init__, which performs an
        # initial step() and therefore calls get_lr().
        self.inner_schedule = inner_schedule
        self.restarts_steps = jagged_restarts_steps
        self.warmup_steps = jagged_restarts_warmup_steps
        self.anneal_steps = jagged_restarts_anneal_steps
        self.min_lr_scale = min_lr_scale

        # ``verbose`` is deprecated in recent PyTorch: only forward it when
        # the inner scheduler actually has it switched on, so construction
        # neither warns nor depends on the attribute being present.
        verbose = getattr(inner_schedule, "verbose", False)
        if verbose:
            super().__init__(optimizer, inner_schedule.last_epoch, verbose)
        else:
            super().__init__(optimizer, inner_schedule.last_epoch)

    def get_lr(self) -> Sequence[float]:
        """Return the inner scheduler's learning rates scaled for this step.

        During the first cycle the rates pass through unchanged. In later
        cycles the scale ramps from ``min_lr_scale`` up to 1 over the warmup
        steps, stays at 1, then ramps back down over the anneal steps.
        """
        # Keep the wrapped scheduler on the same step as this one.
        self.inner_schedule.last_epoch = self.last_epoch

        original = self.inner_schedule.get_lr()
        step = self.last_epoch

        if step < self.restarts_steps:
            scale = 1
        else:
            cycle_pos = step % self.restarts_steps
            if cycle_pos < self.warmup_steps:
                cycle_t = min(1.0, cycle_pos / self.warmup_steps)
            elif cycle_pos > (self.restarts_steps - self.anneal_steps):
                cycle_t = min(
                    1.0,
                    (self.restarts_steps - cycle_pos) / self.anneal_steps,
                )
            else:
                cycle_t = 1
            # Interpolate between min_lr_scale (cycle_t == 0) and 1.
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        if isinstance(original, Sequence):
            return [lr * scale for lr in original]
        # Defensive: a non-sequence value from the inner scheduler is scaled as-is.
        return original * scale
|
|
||||||
|
|||||||
Reference in New Issue
Block a user