fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing * fix: logging import * fix: cfg passed being none * fix: try to fix logging * fix: refactor call to not use accelerate log * fix: try to fix circular import issue * fix: don't drop when skip prepare * chore: remove duplicate line * fix: update warning to mention that sequences will be trimmed * fix: do not drop seq if input_ids don't exist * fix: increase RM unittest sequence length to reduce trim warnings * fix: solve conflicts * fix: default min_seq_len in case of None
This commit is contained in:
@@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
|
||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||
LOG.warning(
|
||||
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
||||
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
||||
)
|
||||
|
||||
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
||||
@@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
|
||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||
LOG.warning(
|
||||
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
||||
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
||||
)
|
||||
|
||||
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
||||
|
||||
@@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
from axolotl.utils.data.shared import load_dataset_w_config
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
drop_long_seq_in_dataset,
|
||||
md5,
|
||||
retry_on_request_exceptions,
|
||||
)
|
||||
@@ -56,7 +57,7 @@ from axolotl.utils.trainer import (
|
||||
process_datasets_for_packing,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
@@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets(
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged datasets")
|
||||
|
||||
if cfg.sample_packing and not cfg.skip_prepare_dataset:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
if not cfg.skip_prepare_dataset:
|
||||
dataset = drop_long_seq_in_dataset(dataset, cfg)
|
||||
|
||||
if cfg.sample_packing:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""data handling helpers"""
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
@@ -6,10 +7,15 @@ import time
|
||||
from enum import Enum
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import requests
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetryStrategy(Enum):
|
||||
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset, dataset
|
||||
|
||||
|
||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
if "input_ids" not in dataset.column_names:
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
|
||||
)
|
||||
return dataset
|
||||
|
||||
drop_long = functools.partial(
|
||||
drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
|
||||
try:
|
||||
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
prior_len = len(dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
|
||||
dataset = dataset.filter(
|
||||
drop_long,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
|
||||
else:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
return lengths
|
||||
return lengths
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Edge case: if input_ids is empty
|
||||
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
drop_long = partial(
|
||||
drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len or 2,
|
||||
)
|
||||
|
||||
try:
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if cfg.model_config_type == "mamba":
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(train_dataset, IterableDataset):
|
||||
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
|
||||
try:
|
||||
prior_len = len(train_dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
||||
|
||||
if eval_dataset:
|
||||
try:
|
||||
prior_len = len(eval_dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
eval_dataset = eval_dataset.filter(
|
||||
drop_long,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
||||
|
||||
def drop_no_trainable_tokens(sample):
|
||||
"""
|
||||
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
||||
@@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(train_dataset, IterableDataset):
|
||||
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
||||
|
||||
Reference in New Issue
Block a user