bugfix for sample packing
This commit is contained in:
@@ -94,19 +94,10 @@ def wrap_dataset_for_tokenized_prompt(
|
|||||||
if prompt_tokenizer.supports_batched:
|
if prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|
||||||
def peek_and_get_columns():
|
|
||||||
# Create a fresh iterator just for peeking
|
|
||||||
temp_iter = iter(dataset)
|
|
||||||
first_example = next(temp_iter)
|
|
||||||
return list(first_example.keys())
|
|
||||||
|
|
||||||
original_columns = peek_and_get_columns()
|
|
||||||
|
|
||||||
# Map the dataset and remove original columns
|
# Map the dataset and remove original columns
|
||||||
# This ensures only tokenized columns remain
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
prompt_tokenizer.tokenize_prompt,
|
prompt_tokenizer.tokenize_prompt,
|
||||||
remove_columns=original_columns,
|
remove_columns=list(dataset.features.keys()),
|
||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
|
||||||
from datasets import IterableDataset, disable_caching, enable_caching
|
from datasets import IterableDataset, disable_caching, enable_caching
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
@@ -23,6 +22,65 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
|
||||||
|
"""
|
||||||
|
Create a filtered IterableDataset that works around a HuggingFace datasets
|
||||||
|
limitation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filtered_generator():
|
||||||
|
"""Generator that yields only samples that pass the filter function."""
|
||||||
|
if batched:
|
||||||
|
batch = []
|
||||||
|
batch_size = 1000 # Process in batches of 1000
|
||||||
|
|
||||||
|
for sample in dataset:
|
||||||
|
batch.append(sample)
|
||||||
|
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
# Create a batch dict from list of samples
|
||||||
|
batch_dict = {}
|
||||||
|
for key in batch[0].keys():
|
||||||
|
batch_dict[key] = [sample[key] for sample in batch]
|
||||||
|
|
||||||
|
# Apply filter function to batch
|
||||||
|
keep_mask = filter_fn(batch_dict)
|
||||||
|
|
||||||
|
# Yield samples that should be kept
|
||||||
|
for i, keep in enumerate(keep_mask):
|
||||||
|
if keep:
|
||||||
|
yield batch[i]
|
||||||
|
|
||||||
|
batch = []
|
||||||
|
|
||||||
|
# Process remaining samples in batch
|
||||||
|
if batch:
|
||||||
|
batch_dict = {}
|
||||||
|
for key in batch[0].keys():
|
||||||
|
batch_dict[key] = [sample[key] for sample in batch]
|
||||||
|
|
||||||
|
keep_mask = filter_fn(batch_dict)
|
||||||
|
|
||||||
|
for i, keep in enumerate(keep_mask):
|
||||||
|
if keep:
|
||||||
|
yield batch[i]
|
||||||
|
else:
|
||||||
|
# For non-batched filtering, apply filter to each sample individually
|
||||||
|
for sample in dataset:
|
||||||
|
if filter_fn(sample):
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
# Create new IterableDataset from the filtered generator
|
||||||
|
filtered_dataset = IterableDataset.from_generator(filtered_generator)
|
||||||
|
|
||||||
|
# Preserve the original features if they exist
|
||||||
|
# pylint:disable=protected-access
|
||||||
|
if hasattr(dataset, "_info") and dataset._info.features is not None:
|
||||||
|
filtered_dataset._info.features = dataset._info.features
|
||||||
|
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def weighted_cross_entropy(
|
def weighted_cross_entropy(
|
||||||
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
||||||
@@ -282,12 +340,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
||||||
train_dataset = train_dataset.filter(
|
|
||||||
drop_no_trainable_tokens,
|
# For IterableDatasets, always use custom filtering to avoid features issues
|
||||||
batched=True,
|
if isinstance(train_dataset, IterableDataset):
|
||||||
**filter_map_kwargs,
|
# IterableDatasets often have None features after transformations,
|
||||||
**drop_long_kwargs,
|
# so we use our custom filter implementation that doesn't rely on features
|
||||||
)
|
train_dataset = _create_filtered_iterable_dataset(
|
||||||
|
train_dataset, drop_no_trainable_tokens, batched=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_dataset = train_dataset.filter(
|
||||||
|
drop_no_trainable_tokens,
|
||||||
|
batched=True,
|
||||||
|
**filter_map_kwargs,
|
||||||
|
**drop_long_kwargs,
|
||||||
|
)
|
||||||
if prior_len:
|
if prior_len:
|
||||||
dropped = prior_len - len(train_dataset)
|
dropped = prior_len - len(train_dataset)
|
||||||
if dropped:
|
if dropped:
|
||||||
|
|||||||
Reference in New Issue
Block a user