diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 55140f151..af1d51a46 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): dict(zip(feature_names, row)) ) for key, val in tokenized_prompt.items(): - for i in range(0, len(val), self.sequence_len): - res[key].append(val[i : i + self.sequence_len]) + res[key].append(val) # If there are no examples left, return an empty dictionary if not res: diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index a6abd8d73..a8e19582e 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): ) try: - min_input_len = np.min(get_dataset_lengths(dataset)) - LOG.debug(f"min_input_len: {min_input_len}") - max_input_len = np.max(get_dataset_lengths(dataset)) - LOG.debug(f"max_input_len: {max_input_len}") + ds_lengths = get_dataset_lengths(dataset, from_arrow=True) + min_input_len = np.min(ds_lengths) + LOG.info(f"min_input_len: {min_input_len}") + max_input_len = np.max(ds_lengths) + LOG.info(f"max_input_len: {max_input_len}") except AttributeError: pass diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py index 8f400c98d..09f1b081c 100755 --- a/src/axolotl/utils/samplers/utils.py +++ b/src/axolotl/utils/samplers/utils.py @@ -4,13 +4,17 @@ helper util to calculate dataset lengths import numpy as np -def get_dataset_lengths(dataset): +def get_dataset_lengths(dataset, from_arrow=False): if "length" in dataset.column_names: lengths = np.array(dataset["length"]) elif "position_ids" in dataset.column_names: position_ids = dataset["position_ids"] lengths = np.array([x[-1] + 1 for x in position_ids]) else: - input_ids = dataset["input_ids"] - lengths = np.array([len(seq) for seq in input_ids]) + if from_arrow: + input_ids = dataset.data.column("input_ids") + lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) + else: + input_ids = dataset["input_ids"] + lengths = np.array([len(seq) for seq in input_ids]) return lengths