Deprecate max packed sequence len (#1141)
This commit is contained in:
@@ -642,10 +642,6 @@ sequence_len: 2048
|
|||||||
# Pad inputs so each step uses constant sized buffers
|
# Pad inputs so each step uses constant sized buffers
|
||||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
# Max sequence length to concatenate training samples together up to
|
|
||||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
|
||||||
# FutureWarning: This will soon be DEPRECATED
|
|
||||||
max_packed_sequence_len: 1024
|
|
||||||
# Use efficient multi-packing with block diagonal attention and per-sequence position_ids. Recommended: set to 'true'
|
# Use efficient multi-packing with block diagonal attention and per-sequence position_ids. Recommended: set to 'true'
|
||||||
sample_packing:
|
sample_packing:
|
||||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||||
|
|||||||
@@ -157,6 +157,9 @@ def normalize_config(cfg):
|
|||||||
if isinstance(cfg.learning_rate, str):
|
if isinstance(cfg.learning_rate, str):
|
||||||
cfg.learning_rate = float(cfg.learning_rate)
|
cfg.learning_rate = float(cfg.learning_rate)
|
||||||
|
|
||||||
|
if isinstance(cfg.pretraining_dataset, dict):
|
||||||
|
cfg.pretraining_dataset = [cfg.pretraining_dataset]
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
@@ -192,18 +195,8 @@ def validate_config(cfg):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||||
)
|
)
|
||||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
|
||||||
raise ValueError(
|
|
||||||
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|
|
||||||
)
|
|
||||||
if cfg.max_packed_sequence_len:
|
if cfg.max_packed_sequence_len:
|
||||||
LOG.warning(
|
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||||
str(
|
|
||||||
PendingDeprecationWarning(
|
|
||||||
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from torch.utils.data import RandomSampler
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||||
@@ -71,9 +71,11 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
name = None
|
name = None
|
||||||
if isinstance(cfg.pretraining_dataset, dict):
|
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||||
path = cfg.pretraining_dataset["path"]
|
cfg.pretraining_dataset[0], dict
|
||||||
name = cfg.pretraining_dataset["name"]
|
):
|
||||||
|
path = cfg.pretraining_dataset[0]["path"]
|
||||||
|
name = cfg.pretraining_dataset[0]["name"]
|
||||||
|
|
||||||
train_dataset = load_pretraining_dataset(
|
train_dataset = load_pretraining_dataset(
|
||||||
path,
|
path,
|
||||||
@@ -88,11 +90,6 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
|
||||||
cfg, train_dataset, eval_dataset, tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
if total_eval_steps == 0:
|
if total_eval_steps == 0:
|
||||||
@@ -163,6 +160,10 @@ def load_tokenized_prepared_datasets(
|
|||||||
else:
|
else:
|
||||||
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||||
LOG.info("Loading raw datasets...")
|
LOG.info("Loading raw datasets...")
|
||||||
|
if not cfg.is_preprocess:
|
||||||
|
LOG.warning(
|
||||||
|
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.seed:
|
if cfg.seed:
|
||||||
seed = cfg.seed
|
seed = cfg.seed
|
||||||
@@ -382,6 +383,9 @@ def load_tokenized_prepared_datasets(
|
|||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
LOG.info("shuffle merged datasets")
|
LOG.info("shuffle merged datasets")
|
||||||
dataset = dataset.shuffle(seed=seed)
|
dataset = dataset.shuffle(seed=seed)
|
||||||
|
|
||||||
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
@@ -419,119 +423,9 @@ def load_prepare_datasets(
|
|||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||||
max_packed_sequence_len = (
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
max_packed_sequence_len = min(
|
|
||||||
max_packed_sequence_len, cfg.sequence_len
|
|
||||||
) # make sure we don't accidentally set it larger than sequence_len
|
|
||||||
|
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
|
||||||
prompters: List[Prompter] = []
|
|
||||||
if cfg.max_packed_sequence_len is not None:
|
|
||||||
# see if we can go ahead and load the stacked dataset
|
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
|
||||||
ds_hash = str(
|
|
||||||
md5(
|
|
||||||
(
|
|
||||||
str(cfg.sequence_len)
|
|
||||||
+ "@"
|
|
||||||
+ str(max_packed_sequence_len)
|
|
||||||
+ seed
|
|
||||||
+ "|".join(
|
|
||||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
|
||||||
)
|
|
||||||
+ "|"
|
|
||||||
+ tokenizer_name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
prepared_ds_path = (
|
|
||||||
Path(cfg.dataset_prepared_path) / ds_hash
|
|
||||||
if cfg.dataset_prepared_path
|
|
||||||
else Path(default_dataset_prepared_path) / ds_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = None
|
|
||||||
use_auth_token = cfg.hf_use_auth_token
|
|
||||||
try:
|
|
||||||
if cfg.push_dataset_to_hub:
|
|
||||||
LOG.info(
|
|
||||||
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
|
||||||
)
|
|
||||||
dataset = load_dataset(
|
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
|
||||||
token=use_auth_token,
|
|
||||||
)
|
|
||||||
dataset = dataset["train"]
|
|
||||||
except Exception: # pylint: disable=broad-except # nosec
|
|
||||||
pass
|
|
||||||
|
|
||||||
if dataset:
|
|
||||||
...
|
|
||||||
elif (
|
|
||||||
cfg.dataset_prepared_path
|
|
||||||
and any(prepared_ds_path.glob("*"))
|
|
||||||
and not cfg.is_preprocess
|
|
||||||
):
|
|
||||||
LOG.info(
|
|
||||||
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
|
||||||
)
|
|
||||||
dataset = load_from_disk(str(prepared_ds_path))
|
|
||||||
LOG.info("Prepared packed dataset loaded from disk...")
|
|
||||||
if cfg.push_dataset_to_hub:
|
|
||||||
LOG.info(
|
|
||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
|
||||||
)
|
|
||||||
dataset.push_to_hub(
|
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.seed:
|
|
||||||
dataset = dataset.shuffle(seed=cfg.seed)
|
|
||||||
|
|
||||||
constant_len_dataset = ConstantLengthDataset(
|
|
||||||
tokenizer,
|
|
||||||
[dataset],
|
|
||||||
seq_length=max_packed_sequence_len,
|
|
||||||
)
|
|
||||||
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
|
|
||||||
dataset = Dataset.from_list(list(constant_len_dataset))
|
|
||||||
|
|
||||||
# filter out bad data
|
|
||||||
# TODO convert to dataset.filter(...)
|
|
||||||
dataset = Dataset.from_list(
|
|
||||||
[
|
|
||||||
d
|
|
||||||
for d in dataset
|
|
||||||
if len(d["input_ids"]) <= cfg.sequence_len
|
|
||||||
and len(d["input_ids"]) > 0
|
|
||||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
|
||||||
and len(d["input_ids"]) == len(d["labels"])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
|
||||||
LOG.info(
|
|
||||||
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
|
||||||
)
|
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
|
||||||
if cfg.push_dataset_to_hub:
|
|
||||||
LOG.info(
|
|
||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
|
||||||
)
|
|
||||||
dataset.push_to_hub(
|
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
|
||||||
private=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
@@ -877,6 +771,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
|
|||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
batched=True,
|
batched=True,
|
||||||
|
batch_size=10_000,
|
||||||
input_columns="text",
|
input_columns="text",
|
||||||
# remove all the existing columns after mapping since they end up having
|
# remove all the existing columns after mapping since they end up having
|
||||||
# a different length than the encoded/tokenized column
|
# a different length than the encoded/tokenized column
|
||||||
|
|||||||
@@ -329,11 +329,7 @@ def load_model(
|
|||||||
LOG.info("patching mixtral with flash attention")
|
LOG.info("patching mixtral with flash attention")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
if (
|
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||||
cfg.is_llama_derived_model
|
|
||||||
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
|
||||||
and not inference
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
LOG.info("patching _expand_mask")
|
LOG.info("patching _expand_mask")
|
||||||
|
|||||||
@@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
|||||||
return weighted_cross_entropy(logits, labels, weights)
|
return weighted_cross_entropy(logits, labels, weights)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_datasets_caching():
|
||||||
|
try:
|
||||||
|
set_caching_enabled(False)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
set_caching_enabled(True)
|
||||||
|
|
||||||
|
|
||||||
def add_position_ids(sample):
|
def add_position_ids(sample):
|
||||||
sample_len = len(sample["input_ids"])
|
sample_len = len(sample["input_ids"])
|
||||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||||
@@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048):
|
|||||||
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def disable_datasets_caching():
|
|
||||||
try:
|
|
||||||
set_caching_enabled(False)
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
set_caching_enabled(True)
|
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
@@ -227,8 +227,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
sampler=RandomSampler(train_dataset),
|
sampler=RandomSampler(train_dataset),
|
||||||
batch_size=cfg.micro_batch_size,
|
batch_size=cfg.micro_batch_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
batch_max_len=cfg.micro_batch_size
|
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
|
||||||
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
|
||||||
lengths=get_dataset_lengths(train_dataset),
|
lengths=get_dataset_lengths(train_dataset),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -324,20 +324,19 @@ class ValidationTest(BaseValidation):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_packing(self):
|
def test_deprecated_packing(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"max_packed_sequence_len": 2048,
|
"max_packed_sequence_len": 1024,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
with self._caplog.at_level(logging.WARNING):
|
with pytest.raises(
|
||||||
|
DeprecationWarning,
|
||||||
|
match=r"`max_packed_sequence_len` is no longer supported",
|
||||||
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
assert any(
|
|
||||||
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
|
||||||
in record.message
|
|
||||||
for record in self._caplog.records
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def test_packing(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
@@ -352,16 +351,6 @@ class ValidationTest(BaseValidation):
|
|||||||
for record in self._caplog.records
|
for record in self._caplog.records
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"max_packed_sequence_len": 2048,
|
|
||||||
"sample_packing": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
|
||||||
with pytest.raises(ValueError, match=regex_exp):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
is_torch_bf16_gpu_available(),
|
is_torch_bf16_gpu_available(),
|
||||||
reason="test should only run on gpus w/o bf16 support",
|
reason="test should only run on gpus w/o bf16 support",
|
||||||
|
|||||||
Reference in New Issue
Block a user