diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index c6b0fe2cf..61191a47f 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): if len(chosen_tokenized["input_ids"]) > max_length: LOG.warning( - f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}", + f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}", ) chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length] @@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): if len(rejected_tokenized["input_ids"]) > max_length: LOG.warning( - f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}", + f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}", ) rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][ diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 79bbb2972..722ad2de2 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.shared import load_dataset_w_config from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, + drop_long_seq_in_dataset, md5, retry_on_request_exceptions, ) @@ -56,7 +57,7 @@ from axolotl.utils.trainer import ( process_datasets_for_packing, ) -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger(__name__) @retry_on_request_exceptions(max_retries=3, delay=5) @@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets( else: LOG.debug("NOT shuffling merged datasets") - if cfg.sample_packing and not cfg.skip_prepare_dataset: - dataset, _ = process_datasets_for_packing(cfg, dataset, None) + if not cfg.skip_prepare_dataset: + dataset = drop_long_seq_in_dataset(dataset, cfg) + + if cfg.sample_packing: + dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 657cbb77c..a6abd8d73 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -1,4 +1,5 @@ """data handling helpers""" + import functools import hashlib import logging @@ -6,10 +7,15 @@ import time from enum import Enum import huggingface_hub +import numpy as np import requests -from datasets import Dataset +from datasets import Dataset, IterableDataset -LOG = logging.getLogger("axolotl") +from axolotl.utils.dict import DictDefault +from axolotl.utils.samplers.utils import get_dataset_lengths +from axolotl.utils.trainer import drop_long_seq + +LOG = logging.getLogger(__name__) class RetryStrategy(Enum): @@ -150,3 +156,53 @@ def deduplicate_and_log_datasets( ) return train_dataset, eval_dataset, dataset + + +def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): + if "input_ids" not in dataset.column_names: + LOG.warning( + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." + ) + return dataset + + drop_long = functools.partial( + drop_long_seq, + sequence_len=cfg.sequence_len, + min_sequence_len=cfg.min_sample_len, + ) + + try: + min_input_len = np.min(get_dataset_lengths(dataset)) + LOG.debug(f"min_input_len: {min_input_len}") + max_input_len = np.max(get_dataset_lengths(dataset)) + LOG.debug(f"max_input_len: {max_input_len}") + except AttributeError: + pass + + try: + prior_len = len(dataset) + except TypeError: + # handle iterable datasets case + prior_len = None + + filter_map_kwargs = {} + if not isinstance(dataset, IterableDataset): + filter_map_kwargs["num_proc"] = cfg.dataset_processes + filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess + + drop_long_kwargs = {} + if filter_map_kwargs: + drop_long_kwargs["desc"] = "Dropping Long Sequences" + + dataset = dataset.filter( + drop_long, + batched=True, + **filter_map_kwargs, + **drop_long_kwargs, + ) + if prior_len: + dropped = prior_len - len(dataset) + if dropped: + LOG.warning(f"Dropped {dropped} long samples from dataset") + + return dataset diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py index e4af4e5f3..4e41c9b44 100755 --- a/src/axolotl/utils/samplers/utils.py +++ b/src/axolotl/utils/samplers/utils.py @@ -13,5 +13,4 @@ def get_dataset_lengths(dataset): else: input_ids = dataset.data.column("input_ids") lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) - return lengths return lengths diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e9bdfb4..61f03e7ad 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,4 +1,5 @@ """Module containing the Trainer class and related functions""" + import json import math import os @@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): Works for both single-example (list[int]) or batched (list[list[int]]). """ + min_sequence_len = min_sequence_len or 2 + input_ids = sample["input_ids"] # Edge case: if input_ids is empty @@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): def process_datasets_for_packing(cfg, train_dataset, eval_dataset): - drop_long = partial( - drop_long_seq, - sequence_len=cfg.sequence_len, - min_sequence_len=cfg.min_sample_len or 2, - ) - - try: - min_input_len = np.min(get_dataset_lengths(train_dataset)) - LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True) - max_input_len = np.max(get_dataset_lengths(train_dataset)) - LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) - except AttributeError: - pass - if cfg.model_config_type == "mamba": LOG.info("dropping attention_mask column") train_dataset = train_dataset.remove_columns("attention_mask") @@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("token_type_ids") - filter_map_kwargs = {} - if not isinstance(train_dataset, IterableDataset): - filter_map_kwargs["num_proc"] = cfg.dataset_processes - filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess - - try: - prior_len = len(train_dataset) - except TypeError: - # handle iterable datasets case - prior_len = None - drop_long_kwargs = {} - if filter_map_kwargs: - drop_long_kwargs["desc"] = "Dropping Long Sequences" - train_dataset = train_dataset.filter( - drop_long, - batched=True, - **filter_map_kwargs, - **drop_long_kwargs, - ) - if prior_len: - dropped = prior_len - len(train_dataset) - if dropped: - LOG.warning(f"Dropped {dropped} long samples from train dataset") - - if eval_dataset: - try: - prior_len = len(eval_dataset) - except TypeError: - # handle iterable datasets case - prior_len = None - eval_dataset = eval_dataset.filter( - drop_long, - **filter_map_kwargs, - **drop_long_kwargs, - ) - if prior_len: - dropped = prior_len - len(eval_dataset) - if dropped: - LOG.warning(f"Dropped {dropped} long samples from eval dataset") - def drop_no_trainable_tokens(sample): """ Drop samples if all labels are -100 (i.e., zero trainable tokens). @@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): except TypeError: # handle iterable datasets case prior_len = None + filter_map_kwargs = {} + if not isinstance(train_dataset, IterableDataset): + filter_map_kwargs["num_proc"] = cfg.dataset_processes + filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess + drop_long_kwargs = {} if filter_map_kwargs: drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens" diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 7360a99dc..d123e6061 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase): "num_labels": 1, "chat_template": "alpaca", "reward_model": True, - "sequence_len": 1024, + "sequence_len": 2048, "pad_to_sequence_len": True, "adapter": "lora", "lora_r": 8,