Deprecate max packed sequence len (#1141)

This commit is contained in:
Wing Lian
2024-01-20 05:11:50 -05:00
committed by GitHub
parent 3db5f2fd17
commit 2ce5c0d68a
6 changed files with 38 additions and 170 deletions

View File

@@ -642,10 +642,6 @@ sequence_len: 2048
# Pad inputs so each step uses constant sized buffers # Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len: pad_to_sequence_len:
# Max sequence length to concatenate training samples together up to
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing: sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on. # Set to 'false' if getting errors during eval with sample_packing on.

View File

@@ -157,6 +157,9 @@ def normalize_config(cfg):
if isinstance(cfg.learning_rate, str): if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate) cfg.learning_rate = float(cfg.learning_rate)
if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset]
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -192,18 +195,8 @@ def validate_config(cfg):
raise ValueError( raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
) )
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
)
if cfg.max_packed_sequence_len: if cfg.max_packed_sequence_len:
LOG.warning( raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
str(
PendingDeprecationWarning(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
)
)
)
if cfg.sample_packing and not cfg.pad_to_sequence_len: if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning( LOG.warning(

View File

@@ -19,7 +19,7 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
@@ -71,9 +71,11 @@ def prepare_dataset(cfg, tokenizer):
else: else:
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
name = None name = None
if isinstance(cfg.pretraining_dataset, dict): if isinstance(cfg.pretraining_dataset, list) and isinstance(
path = cfg.pretraining_dataset["path"] cfg.pretraining_dataset[0], dict
name = cfg.pretraining_dataset["name"] ):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
train_dataset = load_pretraining_dataset( train_dataset = load_pretraining_dataset(
path, path,
@@ -88,11 +90,6 @@ def prepare_dataset(cfg, tokenizer):
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters return train_dataset, eval_dataset, cfg.max_steps, prompters
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer
)
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0: if total_eval_steps == 0:
@@ -163,6 +160,10 @@ def load_tokenized_prepared_datasets(
else: else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...") LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
)
if cfg.seed: if cfg.seed:
seed = cfg.seed seed = cfg.seed
@@ -382,6 +383,9 @@ def load_tokenized_prepared_datasets(
if len(datasets) > 1: if len(datasets) > 1:
LOG.info("shuffle merged datasets") LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed) dataset = dataset.shuffle(seed=seed)
dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
@@ -419,119 +423,9 @@ def load_prepare_datasets(
cfg, cfg,
default_dataset_prepared_path, default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]: ) -> Tuple[Dataset, Dataset, List[Prompter]]:
max_packed_sequence_len = ( dataset, prompters = load_tokenized_prepared_datasets(
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len tokenizer, cfg, default_dataset_prepared_path
) )
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len
tokenizer_name = tokenizer.__class__.__name__
prompters: List[Prompter] = []
if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
md5(
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ seed
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|"
+ tokenizer_name
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(default_dataset_prepared_path) / ds_hash
)
dataset = None
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
LOG.info(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
pass
if dataset:
...
elif (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared packed dataset loaded from disk...")
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
if cfg.seed:
dataset = dataset.shuffle(seed=cfg.seed)
constant_len_dataset = ConstantLengthDataset(
tokenizer,
[dataset],
seq_length=max_packed_sequence_len,
)
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
dataset = Dataset.from_list(list(constant_len_dataset))
# filter out bad data
# TODO convert to dataset.filter(...)
dataset = Dataset.from_list(
[
d
for d in dataset
if len(d["input_ids"]) <= cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)
if cfg.local_rank == 0:
LOG.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
private=True,
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
LOG.info( LOG.info(
@@ -877,6 +771,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
dataset = dataset.map( dataset = dataset.map(
encode, encode,
batched=True, batched=True,
batch_size=10_000,
input_columns="text", input_columns="text",
# remove all the existing columns after mapping since they end up having # remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column # a different length than the encoded/tokenized column

View File

@@ -329,11 +329,7 @@ def load_model(
LOG.info("patching mixtral with flash attention") LOG.info("patching mixtral with flash attention")
replace_mixtral_attn_with_multipack_flash_attn() replace_mixtral_attn_with_multipack_flash_attn()
if ( if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not inference
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask") LOG.info("patching _expand_mask")

View File

@@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
return weighted_cross_entropy(logits, labels, weights) return weighted_cross_entropy(logits, labels, weights)
@contextmanager
def disable_datasets_caching():
    """Turn caching off for the duration of a ``with`` block.

    On entry this calls ``set_caching_enabled(False)``; control is then handed
    to the body of the ``with`` statement. On exit it calls
    ``set_caching_enabled(True)``. Because the re-enable sits in a ``finally``
    clause, it runs whether the body finishes normally or raises.

    NOTE(review): caching is switched back on unconditionally. If it was
    already disabled before entering the block, that prior state is not
    restored -- confirm callers never rely on caching staying off afterwards.
    NOTE(review): ``set_caching_enabled`` is presumably the Hugging Face
    ``datasets`` helper; verify against this module's imports.
    """
    try:
        # Disable caching before yielding so the whole body runs uncached.
        set_caching_enabled(False)
        yield
    finally:
        # Always runs, including when the body raised.
        set_caching_enabled(True)
def add_position_ids(sample): def add_position_ids(sample):
sample_len = len(sample["input_ids"]) sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"])) sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()): with zero_first(is_main_process()):
@@ -227,8 +227,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sampler=RandomSampler(train_dataset), sampler=RandomSampler(train_dataset),
batch_size=cfg.micro_batch_size, batch_size=cfg.micro_batch_size,
drop_last=True, drop_last=True,
batch_max_len=cfg.micro_batch_size batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
* (cfg.max_packed_sequence_len or cfg.sequence_len),
lengths=get_dataset_lengths(train_dataset), lengths=get_dataset_lengths(train_dataset),
) )

View File

@@ -324,20 +324,19 @@ class ValidationTest(BaseValidation):
validate_config(cfg) validate_config(cfg)
def test_packing(self): def test_deprecated_packing(self):
cfg = DictDefault( cfg = DictDefault(
{ {
"max_packed_sequence_len": 2048, "max_packed_sequence_len": 1024,
} }
) )
with self._caplog.at_level(logging.WARNING): with pytest.raises(
DeprecationWarning,
match=r"`max_packed_sequence_len` is no longer supported",
):
validate_config(cfg) validate_config(cfg)
assert any(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
in record.message
for record in self._caplog.records
)
def test_packing(self):
cfg = DictDefault( cfg = DictDefault(
{ {
"sample_packing": True, "sample_packing": True,
@@ -352,16 +351,6 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records for record in self._caplog.records
) )
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"sample_packing": True,
}
)
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
@pytest.mark.skipif( @pytest.mark.skipif(
is_torch_bf16_gpu_available(), is_torch_bf16_gpu_available(),
reason="test should only run on gpus w/o bf16 support", reason="test should only run on gpus w/o bf16 support",