KD Trainer w logprobs (#2303)
* refactor trainer to prevent circular dependencies later fix loader default KD dataset loading and KD with logprobs filter bad rows make batch smaller handle padding/collation for KD datasets make it work flipped the slice cross entropy loss coefficient during KD make sure to multiply against the correct loss chore: lint triton wip no where support v2 trial no torch.exp inside triton kernel no log etc no torch.tensor v3 fix kwarg don't use triton for now better rescaling for temperatures hash for temperature too use kd_alpha in the correct loss method fix kd loss so it's causal (fixes repeating tokens) var naming and add todo chore: lint refactor so we can easily add new loss functions add license block remove references to triton kd for now handle token/logprob shifting support for custom trainer classes from plugins refactor kd chat template loader move more things to kd plugin remove moved class from import make plugin setup concise increase logging around loading plugins add copyrights remove duplicate code more info on preprocess for kd and fix import be a bit pickier about loading dynamic prompt strategies kd sample packing make loss torch script compat support streaming for processing sft datasts? improve iterable support ensure that batch vs single is done properly tweak check for batched prompt data reward can use same batch check fix reward trainer calls for tokenization improve check for batched reward model doesn't work well with batched add kd trainer e2e test linting rename test files so it gets picked up make the kd e2e fit in vram for ci and add lora version set lora_dropout explicitly lower lr make sure to set tokenizer from l3 70b and save safetensors make sure to use the correct tokenizer fix adapter model check make sure to use tensorboard to capture loss for checks chore: lint chore: lint improve logprob masking and shift in trainer more fixes try tests for kd on l40s don't shift student logits for kd no batching for kd chat templates make sure to truncate logprobs if there are more than top_k change up logic so we always truncate to top_k use iter instead of tuple fix finding the top-k rather than assuming first position has the correct val apply z-score scaling to kd kd loss needs to be calculated in full precision Always re-normalize teacher distribution various fixes * support for configurable top-k/softmax ordering * add attribute check for filter rows and lint * fix logic * handle none case for conversion to int * fix student logit off by one * set kd_temp to 1.0 for test loss * address PR feedback
This commit is contained in:
@@ -3,11 +3,12 @@
|
||||
import functools
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
Sequence,
|
||||
Value,
|
||||
concatenate_datasets,
|
||||
@@ -17,7 +18,7 @@ from datasets import (
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
@@ -59,7 +60,7 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_local_main_process()):
|
||||
@@ -70,6 +71,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="train",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
_, eval_dataset, _ = load_prepare_datasets(
|
||||
tokenizer,
|
||||
@@ -77,6 +79,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
@@ -84,6 +87,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
# Load streaming dataset if pretraining_dataset is given
|
||||
@@ -139,6 +143,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
@@ -170,6 +175,7 @@ def load_tokenized_prepared_datasets(
|
||||
default_dataset_prepared_path,
|
||||
split="train",
|
||||
processor=None,
|
||||
preprocess_iterable: Optional[bool] = None,
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||
tokenizer_name = cfg.tokenizer_config
|
||||
@@ -184,10 +190,11 @@ def load_tokenized_prepared_datasets(
|
||||
+ "@"
|
||||
+ str(cfg.group_by_length)
|
||||
+ "@"
|
||||
+ str(cfg.kd_temperature or 1.0)
|
||||
+ "|".join(
|
||||
sorted(
|
||||
[
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}"
|
||||
for d in cfg_datasets
|
||||
]
|
||||
)
|
||||
@@ -262,13 +269,25 @@ def load_tokenized_prepared_datasets(
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
elif dataset.preprocess_shards and not dataset.shards:
|
||||
for shard in range(dataset.preprocess_shards):
|
||||
yield DictDefault(
|
||||
{
|
||||
**dataset,
|
||||
"shards": dataset.preprocess_shards,
|
||||
"shards_idx": shard,
|
||||
}
|
||||
)
|
||||
else:
|
||||
yield dataset
|
||||
|
||||
streaming_ds = False
|
||||
if preprocess_iterable:
|
||||
streaming_ds = True
|
||||
# pylint: disable=invalid-name
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token
|
||||
config_dataset, use_auth_token, streaming=streaming_ds
|
||||
)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
@@ -325,7 +344,21 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
if isinstance(dataset, IterableDataset):
|
||||
|
||||
def gen_from_iter_ds(_ds, _=None):
|
||||
yield from _ds
|
||||
|
||||
ds_from_iter = Dataset.from_generator(
|
||||
functools.partial(gen_from_iter_ds, dataset),
|
||||
features=dataset.features,
|
||||
num_proc=cfg.dataset_processes,
|
||||
split=split,
|
||||
gen_kwargs={"_": list(range(cfg.dataset_processes))},
|
||||
)
|
||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||
else:
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
if cfg.push_dataset_to_hub:
|
||||
LOG.info(
|
||||
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
|
||||
@@ -345,6 +378,7 @@ def load_prepare_datasets(
|
||||
default_dataset_prepared_path,
|
||||
split="train",
|
||||
processor=None,
|
||||
preprocess_iterable: Optional[bool] = False,
|
||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
tokenizer,
|
||||
@@ -352,6 +386,7 @@ def load_prepare_datasets(
|
||||
default_dataset_prepared_path,
|
||||
split=split,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||
@@ -451,7 +486,7 @@ def get_dataset_wrapper(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -464,7 +499,7 @@ def get_dataset_wrapper(
|
||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||
):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -487,7 +522,7 @@ def get_dataset_wrapper(
|
||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||
else:
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -500,7 +535,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -514,7 +549,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -528,7 +563,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -542,7 +577,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -556,7 +591,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -570,7 +605,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -584,7 +619,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -598,7 +633,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user