fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing * fix: logging import * fix: cfg passed being none * fix: try to fix logging * fix: refactor call to not use accelerate log * fix: try to fix circular import issue * fix: don't drop when skip prepare * chore: remove duplicate line * fix: update warning to mention that sequences will be trimmed * fix: do not drop seq if input_ids don't exist * fix: increase RM unittest sequence length to reduce trim warnings * fix: solve conflicts * fix: default min_seq_len in case of None
This commit is contained in:
@@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
||||||
)
|
)
|
||||||
|
|
||||||
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
||||||
@@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
||||||
)
|
)
|
||||||
|
|
||||||
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
|||||||
from axolotl.utils.data.shared import load_dataset_w_config
|
from axolotl.utils.data.shared import load_dataset_w_config
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
|
drop_long_seq_in_dataset,
|
||||||
md5,
|
md5,
|
||||||
retry_on_request_exceptions,
|
retry_on_request_exceptions,
|
||||||
)
|
)
|
||||||
@@ -56,7 +57,7 @@ from axolotl.utils.trainer import (
|
|||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
@@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
else:
|
else:
|
||||||
LOG.debug("NOT shuffling merged datasets")
|
LOG.debug("NOT shuffling merged datasets")
|
||||||
|
|
||||||
if cfg.sample_packing and not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset = drop_long_seq_in_dataset(dataset, cfg)
|
||||||
|
|
||||||
|
if cfg.sample_packing:
|
||||||
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""data handling helpers"""
|
"""data handling helpers"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
@@ -6,10 +7,15 @@ import time
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from datasets import Dataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import drop_long_seq
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RetryStrategy(Enum):
|
class RetryStrategy(Enum):
|
||||||
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return train_dataset, eval_dataset, dataset
|
return train_dataset, eval_dataset, dataset
|
||||||
|
|
||||||
|
|
||||||
|
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||||
|
if "input_ids" not in dataset.column_names:
|
||||||
|
LOG.warning(
|
||||||
|
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
|
||||||
|
)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
drop_long = functools.partial(
|
||||||
|
drop_long_seq,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||||
|
LOG.debug(f"min_input_len: {min_input_len}")
|
||||||
|
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||||
|
LOG.debug(f"max_input_len: {max_input_len}")
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
prior_len = len(dataset)
|
||||||
|
except TypeError:
|
||||||
|
# handle iterable datasets case
|
||||||
|
prior_len = None
|
||||||
|
|
||||||
|
filter_map_kwargs = {}
|
||||||
|
if not isinstance(dataset, IterableDataset):
|
||||||
|
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||||
|
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||||
|
|
||||||
|
drop_long_kwargs = {}
|
||||||
|
if filter_map_kwargs:
|
||||||
|
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||||
|
|
||||||
|
dataset = dataset.filter(
|
||||||
|
drop_long,
|
||||||
|
batched=True,
|
||||||
|
**filter_map_kwargs,
|
||||||
|
**drop_long_kwargs,
|
||||||
|
)
|
||||||
|
if prior_len:
|
||||||
|
dropped = prior_len - len(dataset)
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|||||||
@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
|
|||||||
else:
|
else:
|
||||||
input_ids = dataset.data.column("input_ids")
|
input_ids = dataset.data.column("input_ids")
|
||||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||||
return lengths
|
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
"""
|
"""
|
||||||
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
# Edge case: if input_ids is empty
|
# Edge case: if input_ids is empty
|
||||||
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
drop_long = partial(
|
|
||||||
drop_long_seq,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len or 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type == "mamba":
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||||
|
|
||||||
filter_map_kwargs = {}
|
|
||||||
if not isinstance(train_dataset, IterableDataset):
|
|
||||||
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
|
||||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
|
||||||
|
|
||||||
try:
|
|
||||||
prior_len = len(train_dataset)
|
|
||||||
except TypeError:
|
|
||||||
# handle iterable datasets case
|
|
||||||
prior_len = None
|
|
||||||
drop_long_kwargs = {}
|
|
||||||
if filter_map_kwargs:
|
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
|
||||||
train_dataset = train_dataset.filter(
|
|
||||||
drop_long,
|
|
||||||
batched=True,
|
|
||||||
**filter_map_kwargs,
|
|
||||||
**drop_long_kwargs,
|
|
||||||
)
|
|
||||||
if prior_len:
|
|
||||||
dropped = prior_len - len(train_dataset)
|
|
||||||
if dropped:
|
|
||||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
|
||||||
|
|
||||||
if eval_dataset:
|
|
||||||
try:
|
|
||||||
prior_len = len(eval_dataset)
|
|
||||||
except TypeError:
|
|
||||||
# handle iterable datasets case
|
|
||||||
prior_len = None
|
|
||||||
eval_dataset = eval_dataset.filter(
|
|
||||||
drop_long,
|
|
||||||
**filter_map_kwargs,
|
|
||||||
**drop_long_kwargs,
|
|
||||||
)
|
|
||||||
if prior_len:
|
|
||||||
dropped = prior_len - len(eval_dataset)
|
|
||||||
if dropped:
|
|
||||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
|
||||||
|
|
||||||
def drop_no_trainable_tokens(sample):
|
def drop_no_trainable_tokens(sample):
|
||||||
"""
|
"""
|
||||||
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
||||||
@@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
# handle iterable datasets case
|
# handle iterable datasets case
|
||||||
prior_len = None
|
prior_len = None
|
||||||
|
filter_map_kwargs = {}
|
||||||
|
if not isinstance(train_dataset, IterableDataset):
|
||||||
|
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||||
|
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
|||||||
"num_labels": 1,
|
"num_labels": 1,
|
||||||
"chat_template": "alpaca",
|
"chat_template": "alpaca",
|
||||||
"reward_model": True,
|
"reward_model": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 2048,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
|
|||||||
Reference in New Issue
Block a user