remove unused
This commit is contained in:
@@ -1,61 +0,0 @@
|
|||||||
# Example configuration for streaming pretraining
|
|
||||||
# This demonstrates how to pretrain on large datasets that don't fit in memory
|
|
||||||
|
|
||||||
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
|
|
||||||
# Required: max_steps for streaming pretraining
|
|
||||||
max_steps: 10000
|
|
||||||
|
|
||||||
# Pretraining dataset configuration
|
|
||||||
# These are automatically streamed
|
|
||||||
pretraining_dataset:
|
|
||||||
- path: allenai/c4
|
|
||||||
name: en
|
|
||||||
type: pretrain
|
|
||||||
# Optional: skip N samples (useful for resuming)
|
|
||||||
# skip: 1000000
|
|
||||||
|
|
||||||
# Can also use multiple pretraining datasets
|
|
||||||
# pretraining_dataset:
|
|
||||||
# - path: allenai/c4
|
|
||||||
# name: en
|
|
||||||
# type: pretrain
|
|
||||||
# - path: HuggingFaceFW/fineweb
|
|
||||||
# type: pretrain
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
|
|
||||||
# Sequence and packing configuration
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
pretrain_multipack_attn: true
|
|
||||||
pretrain_multipack_buffer_size: 10000 # Buffer size for multipack batching
|
|
||||||
|
|
||||||
# Training hyperparameters
|
|
||||||
gradient_accumulation_steps: 8
|
|
||||||
micro_batch_size: 4
|
|
||||||
optimizer: adamw_torch
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 3e-4
|
|
||||||
|
|
||||||
# Memory optimizations
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
gradient_checkpointing: true
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
# Checkpointing and logging
|
|
||||||
output_dir: ./outputs/pretrain-streaming
|
|
||||||
logging_steps: 10
|
|
||||||
save_steps: 500
|
|
||||||
save_total_limit: 3 # Keep only last 3 checkpoints
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
|
|
||||||
# Optional: enable wandb for monitoring
|
|
||||||
# wandb_project: streaming-pretrain
|
|
||||||
# wandb_entity: your-entity
|
|
||||||
# wandb_name: c4-pretrain
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
# Example configuration for streaming SFT training
|
|
||||||
# This enables training on datasets larger than memory by streaming them from HuggingFace Hub
|
|
||||||
|
|
||||||
base_model: HuggingFaceTB/SmolLM2-135M
|
|
||||||
|
|
||||||
# Enable streaming mode for datasets
|
|
||||||
streaming: true
|
|
||||||
|
|
||||||
# When using streaming, max_steps is required
|
|
||||||
max_steps: 3 # Just test a few steps
|
|
||||||
|
|
||||||
# Training datasets - these will be streamed
|
|
||||||
# datasets:
|
|
||||||
# - path: tatsu-lab/alpaca
|
|
||||||
# type: alpaca
|
|
||||||
# split: train
|
|
||||||
|
|
||||||
pretraining_dataset:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
split: train
|
|
||||||
|
|
||||||
# Dataset configuration
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true # Enable multipack batching
|
|
||||||
pretrain_multipack_attn: true # Enable multipack attention masking
|
|
||||||
pretrain_multipack_buffer_size: 1000 # Buffer size for packing (smaller for streaming SFT)
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|endoftext|>
|
|
||||||
|
|
||||||
# Training hyperparameters
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1 # Always 1 for multipack - sequences are packed into single samples
|
|
||||||
num_epochs: 1 # With streaming, typically use max_steps instead
|
|
||||||
optimizer: adamw_torch
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
# Enable efficient training
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
gradient_checkpointing: true
|
|
||||||
flash_attention: true # Enable flash attention with multipack
|
|
||||||
|
|
||||||
# Logging and checkpointing
|
|
||||||
logging_steps: 10
|
|
||||||
eval_steps: 100
|
|
||||||
save_steps: 200
|
|
||||||
output_dir: ./outputs/streaming-model
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
warmup_steps: 100
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
"""Streaming dataset with multipack support for SFT."""
|
|
||||||
|
|
||||||
import functools
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from datasets import Dataset, IterableDataset
|
|
||||||
from torch.utils.data import RandomSampler
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|
||||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_packed_sft_streaming(
|
|
||||||
collate_fn,
|
|
||||||
ds_wrapper_fn,
|
|
||||||
examples: Dict[str, List],
|
|
||||||
dataset_config,
|
|
||||||
tokenizer,
|
|
||||||
cfg,
|
|
||||||
d_base_type: str,
|
|
||||||
d_prompt_style: str | None,
|
|
||||||
processor: Any | None,
|
|
||||||
max_seq_length: int = 2048,
|
|
||||||
batch_size: int = 4,
|
|
||||||
multipack_attn: Optional[bool] = True,
|
|
||||||
) -> Dict[str, List]:
|
|
||||||
"""
|
|
||||||
Encode and pack streaming SFT data similar to how pretraining does it.
|
|
||||||
|
|
||||||
This function:
|
|
||||||
1. Tokenizes the examples using the dataset wrapper
|
|
||||||
2. Adds position_ids for each sequence
|
|
||||||
3. Uses MultipackBatchSampler to efficiently pack sequences
|
|
||||||
4. Applies the collator to handle attention masks properly
|
|
||||||
|
|
||||||
Args:
|
|
||||||
collate_fn: Collator function for handling batches
|
|
||||||
ds_wrapper_fn: Function to get the dataset wrapper for tokenization
|
|
||||||
examples: Dict of lists containing the raw examples
|
|
||||||
dataset_config: Configuration for the dataset
|
|
||||||
tokenizer: Tokenizer to use
|
|
||||||
cfg: Main configuration
|
|
||||||
d_base_type: Dataset base type
|
|
||||||
d_prompt_style: Prompt style
|
|
||||||
processor: Optional processor for multimodal
|
|
||||||
max_seq_length: Maximum sequence length
|
|
||||||
batch_size: Batch size for packing
|
|
||||||
multipack_attn: Whether to use multipack attention
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict of packed and processed data ready for training
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
|
||||||
|
|
||||||
# Convert examples to Dataset for processing
|
|
||||||
temp_dataset = Dataset.from_dict(examples)
|
|
||||||
|
|
||||||
# Apply the dataset wrapper to tokenize
|
|
||||||
train_dataset, _ = get_dataset_wrapper(
|
|
||||||
dataset_config=dataset_config,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
cfg=cfg,
|
|
||||||
dataset_base_type=d_base_type,
|
|
||||||
dataset=temp_dataset,
|
|
||||||
dataset_prompt_style=d_prompt_style,
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process for packing - add position_ids and filter long sequences
|
|
||||||
train_dataset = process_pretraining_datasets_for_packing(
|
|
||||||
train_dataset,
|
|
||||||
max_seq_length,
|
|
||||||
skip_position_ids=not multipack_attn,
|
|
||||||
drop_attention_mask=multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use MultipackBatchSampler to create efficient packed batches
|
|
||||||
sampler = MultipackBatchSampler(
|
|
||||||
sampler=RandomSampler(train_dataset),
|
|
||||||
lengths=get_dataset_lengths(train_dataset),
|
|
||||||
batch_size=1, # We pack multiple sequences into one "batch"
|
|
||||||
batch_max_len=batch_size * max_seq_length, # Total tokens in packed batch
|
|
||||||
drop_last=True,
|
|
||||||
num_processes=1, # Single process for streaming
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect packed data
|
|
||||||
chunked_data = defaultdict(list)
|
|
||||||
|
|
||||||
for batch in sampler:
|
|
||||||
for data in batch:
|
|
||||||
features = train_dataset[data]
|
|
||||||
|
|
||||||
# Clean up unnecessary fields
|
|
||||||
for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
|
|
||||||
if field in features:
|
|
||||||
del features[field]
|
|
||||||
|
|
||||||
# Ensure labels exist
|
|
||||||
if "labels" not in features:
|
|
||||||
features["labels"] = features["input_ids"].copy()
|
|
||||||
|
|
||||||
# Apply collator to handle padding and attention masks
|
|
||||||
collated_features = collate_fn(features)
|
|
||||||
|
|
||||||
# Collect features
|
|
||||||
for feature in features.keys():
|
|
||||||
if feature == "length":
|
|
||||||
continue
|
|
||||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
|
||||||
|
|
||||||
return chunked_data
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_streaming_sft_dataset_with_packing(
|
|
||||||
dataset: IterableDataset,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
cfg,
|
|
||||||
dataset_config,
|
|
||||||
d_base_type: str,
|
|
||||||
d_prompt_style: str | None,
|
|
||||||
processor: Any | None,
|
|
||||||
max_tokens: int = 2048,
|
|
||||||
buffer_size: int = 10_000,
|
|
||||||
) -> IterableDataset:
|
|
||||||
"""
|
|
||||||
Wrap a streaming SFT dataset with tokenization and multipack batching.
|
|
||||||
|
|
||||||
This creates properly packed batches with:
|
|
||||||
- Multiple sequences concatenated together
|
|
||||||
- Position IDs that reset for each sequence
|
|
||||||
- Attention masks that prevent cross-attention between sequences
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: The streaming dataset to wrap
|
|
||||||
tokenizer: Tokenizer to use
|
|
||||||
cfg: Configuration object
|
|
||||||
dataset_config: Dataset configuration
|
|
||||||
d_base_type: Base dataset type
|
|
||||||
d_prompt_style: Prompt style
|
|
||||||
processor: Optional processor for multimodal
|
|
||||||
max_tokens: Maximum sequence length
|
|
||||||
buffer_size: Buffer size for streaming/shuffling
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Wrapped streaming dataset with multipack batching
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Apply shuffling if configured
|
|
||||||
if cfg.shuffle_merged_datasets:
|
|
||||||
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
|
|
||||||
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
|
|
||||||
|
|
||||||
# Get column names from first sample
|
|
||||||
remove_columns = []
|
|
||||||
for first_row in dataset:
|
|
||||||
remove_columns = list(first_row.keys())
|
|
||||||
break
|
|
||||||
|
|
||||||
# Reset dataset after peeking
|
|
||||||
if cfg.shuffle_merged_datasets:
|
|
||||||
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
|
|
||||||
|
|
||||||
# Create the collator for multipack
|
|
||||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
|
||||||
tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
pad_to_multiple_of=max_tokens,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the encoding function
|
|
||||||
# batch_size here refers to how many sequences to pack together to fill max_tokens
|
|
||||||
# The actual batching happens at the DataLoader level with micro_batch_size=1
|
|
||||||
pack_batch_size = max(
|
|
||||||
1, max_tokens // 512
|
|
||||||
) # Estimate based on typical sequence lengths
|
|
||||||
|
|
||||||
encode_fn = functools.partial(
|
|
||||||
encode_packed_sft_streaming,
|
|
||||||
collate_fn,
|
|
||||||
None, # ds_wrapper_fn will be created inside
|
|
||||||
dataset_config=dataset_config,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
cfg=cfg,
|
|
||||||
d_base_type=d_base_type,
|
|
||||||
d_prompt_style=d_prompt_style,
|
|
||||||
processor=processor,
|
|
||||||
max_seq_length=max_tokens,
|
|
||||||
batch_size=pack_batch_size,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Map the encoding function over the streaming dataset
|
|
||||||
# This will process data in batches and apply packing
|
|
||||||
dataset = dataset.map(
|
|
||||||
encode_fn,
|
|
||||||
batched=True,
|
|
||||||
batch_size=buffer_size, # Process large batches for efficiency
|
|
||||||
remove_columns=remove_columns,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set format for PyTorch
|
|
||||||
dataset = dataset.with_format("torch")
|
|
||||||
|
|
||||||
# IMPORTANT: Set micro_batch_size to 1 since we've already packed
|
|
||||||
# This prevents the trainer from trying to batch our packed sequences
|
|
||||||
cfg.micro_batch_size = 1
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
@@ -1,211 +0,0 @@
|
|||||||
"""Simple streaming SFT with multipack support."""
|
|
||||||
|
|
||||||
import functools
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from datasets import Dataset, IterableDataset
|
|
||||||
from torch.utils.data import RandomSampler
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|
||||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingMultipackDataset:
|
|
||||||
"""Dataset that handles streaming with multipack on-the-fly."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
base_dataset: IterableDataset,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
cfg,
|
|
||||||
dataset_config,
|
|
||||||
d_base_type: str,
|
|
||||||
d_prompt_style: str | None,
|
|
||||||
processor: Any | None,
|
|
||||||
max_tokens: int = 2048,
|
|
||||||
pack_length: int = 4, # How many sequences to collect before packing
|
|
||||||
):
|
|
||||||
self.base_dataset = base_dataset
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.cfg = cfg
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.pack_length = pack_length
|
|
||||||
|
|
||||||
# Create the dataset wrapper once
|
|
||||||
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
|
||||||
|
|
||||||
# Create a dummy dataset to get the wrapper
|
|
||||||
dummy_data = {"text": ["dummy"], "instruction": ["dummy"], "output": ["dummy"]}
|
|
||||||
dummy_dataset = Dataset.from_dict(dummy_data)
|
|
||||||
|
|
||||||
self.dataset_wrapper, _ = get_dataset_wrapper(
|
|
||||||
dataset_config=dataset_config,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
cfg=cfg,
|
|
||||||
dataset_base_type=d_base_type,
|
|
||||||
dataset=dummy_dataset,
|
|
||||||
dataset_prompt_style=d_prompt_style,
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create collator for packing
|
|
||||||
self.collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
|
||||||
tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
pad_to_multiple_of=max_tokens,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
"""Iterator that yields packed samples."""
|
|
||||||
buffer = []
|
|
||||||
|
|
||||||
for sample in self.base_dataset:
|
|
||||||
# Convert single sample to dataset for processing
|
|
||||||
temp_dataset = Dataset.from_dict({k: [v] for k, v in sample.items()})
|
|
||||||
|
|
||||||
# Tokenize using the dataset wrapper
|
|
||||||
try:
|
|
||||||
tokenized = self.dataset_wrapper.__class__(
|
|
||||||
temp_dataset,
|
|
||||||
**{
|
|
||||||
k: v
|
|
||||||
for k, v in self.dataset_wrapper.__dict__.items()
|
|
||||||
if not k.startswith("_")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the tokenized sample
|
|
||||||
if len(tokenized) > 0:
|
|
||||||
tokenized_sample = tokenized[0]
|
|
||||||
|
|
||||||
# Add to buffer
|
|
||||||
buffer.append(tokenized_sample)
|
|
||||||
|
|
||||||
# When buffer is full, pack and yield
|
|
||||||
if len(buffer) >= self.pack_length:
|
|
||||||
packed_sample = self._pack_buffer(buffer)
|
|
||||||
if packed_sample:
|
|
||||||
yield packed_sample
|
|
||||||
buffer = []
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
LOG.warning(f"Failed to process sample: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Process remaining buffer
|
|
||||||
if buffer:
|
|
||||||
packed_sample = self._pack_buffer(buffer)
|
|
||||||
if packed_sample:
|
|
||||||
yield packed_sample
|
|
||||||
|
|
||||||
def _pack_buffer(self, buffer: List[Dict]) -> Optional[Dict]:
|
|
||||||
"""Pack a buffer of tokenized samples."""
|
|
||||||
if not buffer:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create dataset from buffer
|
|
||||||
temp_dataset = Dataset.from_list(buffer)
|
|
||||||
|
|
||||||
# Add position_ids and process for packing
|
|
||||||
temp_dataset = process_pretraining_datasets_for_packing(
|
|
||||||
temp_dataset,
|
|
||||||
self.max_tokens,
|
|
||||||
skip_position_ids=not self.cfg.pretrain_multipack_attn,
|
|
||||||
drop_attention_mask=self.cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use MultipackBatchSampler to create packed batches
|
|
||||||
sampler = MultipackBatchSampler(
|
|
||||||
sampler=RandomSampler(temp_dataset),
|
|
||||||
lengths=get_dataset_lengths(temp_dataset),
|
|
||||||
batch_size=1,
|
|
||||||
batch_max_len=self.max_tokens,
|
|
||||||
drop_last=False,
|
|
||||||
num_processes=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get packed data
|
|
||||||
for batch in sampler:
|
|
||||||
if batch and batch[0]: # Check if we have data
|
|
||||||
features = []
|
|
||||||
for idx in batch[0]: # batch[0] contains the indices
|
|
||||||
sample = temp_dataset[idx]
|
|
||||||
if "labels" not in sample:
|
|
||||||
sample["labels"] = sample["input_ids"].copy()
|
|
||||||
features.append(sample)
|
|
||||||
|
|
||||||
# Apply collator to create final packed sample
|
|
||||||
if features:
|
|
||||||
packed = self.collator(features)
|
|
||||||
return {
|
|
||||||
k: v.squeeze(0) if v.dim() > 1 else v
|
|
||||||
for k, v in packed.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
LOG.warning(f"Failed to pack buffer: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_streaming_sft_dataset_simple(
|
|
||||||
dataset: IterableDataset,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
cfg,
|
|
||||||
dataset_config,
|
|
||||||
d_base_type: str,
|
|
||||||
d_prompt_style: str | None,
|
|
||||||
processor: Any | None,
|
|
||||||
max_tokens: int = 2048,
|
|
||||||
buffer_size: int = 10_000,
|
|
||||||
) -> IterableDataset:
|
|
||||||
"""
|
|
||||||
Wrap a streaming SFT dataset with simple multipack batching.
|
|
||||||
|
|
||||||
This approach processes samples in small groups rather than large batches,
|
|
||||||
avoiding the repeated processing issue.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Apply shuffling if configured
|
|
||||||
if cfg.shuffle_merged_datasets:
|
|
||||||
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
|
|
||||||
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
|
|
||||||
|
|
||||||
# Create the streaming multipack dataset
|
|
||||||
multipack_dataset = StreamingMultipackDataset(
|
|
||||||
dataset,
|
|
||||||
tokenizer,
|
|
||||||
cfg,
|
|
||||||
dataset_config,
|
|
||||||
d_base_type,
|
|
||||||
d_prompt_style,
|
|
||||||
processor,
|
|
||||||
max_tokens,
|
|
||||||
pack_length=max(1, max_tokens // 512), # Estimate sequences per pack
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert back to IterableDataset
|
|
||||||
class IterableWrapper:
|
|
||||||
def __init__(self, dataset):
|
|
||||||
self.dataset = dataset
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self.dataset)
|
|
||||||
|
|
||||||
wrapped = IterableWrapper(multipack_dataset)
|
|
||||||
|
|
||||||
# Set micro_batch_size to 1 since sequences are already packed
|
|
||||||
cfg.micro_batch_size = 1
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
Reference in New Issue
Block a user