fix: drop long seq even if not sample packing (#2211)

* fix: drop long seq even if not sample packing

* fix: logging import

* fix: cfg passed being none

* fix: try to fix logging

* fix: refactor call to not use accelerate log

* fix: try to fix circular import issue

* fix: don't drop when skip prepare

* chore: remove duplicate line

* fix: update warning to mention that sequences will be trimmed

* fix: do not drop seq if input_ids don't exist

* fix: increase RM unittest sequence length to reduce trim warnings

* fix: solve conflicts

* fix: default min_seq_len in case of None
This commit is contained in:
NanoCode012
2025-02-04 21:43:35 +07:00
committed by GitHub
parent 158330ab60
commit a620d481e2
6 changed files with 76 additions and 63 deletions

View File

@@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(chosen_tokenized["input_ids"]) > max_length: if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning( LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}", f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
) )
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length] chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(rejected_tokenized["input_ids"]) > max_length: if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning( LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}", f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
) )
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][ rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

View File

@@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import ( from axolotl.utils.data.utils import (
deduplicate_and_log_datasets, deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
md5, md5,
retry_on_request_exceptions, retry_on_request_exceptions,
) )
@@ -56,7 +57,7 @@ from axolotl.utils.trainer import (
process_datasets_for_packing, process_datasets_for_packing,
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
@@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets(
else: else:
LOG.debug("NOT shuffling merged datasets") LOG.debug("NOT shuffling merged datasets")
if cfg.sample_packing and not cfg.skip_prepare_dataset: if not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None) dataset = drop_long_seq_in_dataset(dataset, cfg)
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")

View File

@@ -1,4 +1,5 @@
"""data handling helpers""" """data handling helpers"""
import functools import functools
import hashlib import hashlib
import logging import logging
@@ -6,10 +7,15 @@ import time
from enum import Enum from enum import Enum
import huggingface_hub import huggingface_hub
import numpy as np
import requests import requests
from datasets import Dataset from datasets import Dataset, IterableDataset
LOG = logging.getLogger("axolotl") from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
LOG = logging.getLogger(__name__)
class RetryStrategy(Enum): class RetryStrategy(Enum):
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
) )
return train_dataset, eval_dataset, dataset return train_dataset, eval_dataset, dataset
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
)
try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass
try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset

View File

@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
else: else:
input_ids = dataset.data.column("input_ids") input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
return lengths return lengths

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import json import json
import math import math
import os import os
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
Works for both single-example (list[int]) or batched (list[list[int]]). Works for both single-example (list[int]) or batched (list[list[int]]).
""" """
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"] input_ids = sample["input_ids"]
# Edge case: if input_ids is empty # Edge case: if input_ids is empty
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
try:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
pass
if cfg.model_config_type == "mamba": if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names: if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids") eval_dataset = eval_dataset.remove_columns("token_type_ids")
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_long,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
def drop_no_trainable_tokens(sample): def drop_no_trainable_tokens(sample):
""" """
Drop samples if all labels are -100 (i.e., zero trainable tokens). Drop samples if all labels are -100 (i.e., zero trainable tokens).
@@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
except TypeError: except TypeError:
# handle iterable datasets case # handle iterable datasets case
prior_len = None prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens" drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"

View File

@@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"num_labels": 1, "num_labels": 1,
"chat_template": "alpaca", "chat_template": "alpaca",
"reward_model": True, "reward_model": True,
"sequence_len": 1024, "sequence_len": 2048,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,