Compare commits
9 Commits
streaming-...squash_pos
| Author | SHA1 | Date |
|---|---|---|
| | 21ba1cd3f1 | |
| | eea7a006e1 | |
| | ab4d604a8f | |
| | 0fa752e58b | |
| | 08e517ea48 | |
| | 07fd22f39b | |
| | 06eaf6c448 | |
| | 050210e637 | |
| | 05cedbfb1e | |
@@ -12,5 +12,6 @@ reviews:
   auto_review:
     enabled: true
     drafts: false
+    auto_incremental_review: true
 chat:
   auto_reply: true
@@ -41,6 +41,12 @@ model, and final model output, you may need at least 3TB of free disk space to k
 axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
 ```
 
+To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
+training of the 120B model using Baseten Truss. You can read more about this recipe on
+[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
+be found on their
+[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
+
 ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
 See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
 
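Until that upstream Transformers fix lands, the rename can be scripted. A minimal sketch, not part of this PR; the output path and the `FSDP`-prefixed architecture name shown in the comment are illustrative assumptions:

```python
import json
from pathlib import Path

# Hypothetical output location; point this at your run's actual output_dir.
config_path = Path("./outputs/gpt-oss-out/config.json")

config = json.loads(config_path.read_text())
# Strip the spurious "FSDP" prefix Transformers adds to the architecture name,
# e.g. "FSDPGptOssForCausalLM" -> "GptOssForCausalLM".
config["architectures"] = [
    arch.removeprefix("FSDP") for arch in config.get("architectures", [])
]
config_path.write_text(json.dumps(config, indent=2))
```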
@@ -61,9 +67,23 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
 
 ### Inferencing your fine-tuned model
 
+#### vLLM
+
 GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
 for more information about using a special vllm-openai docker image for inferencing with vLLM.
 
+Optionally, vLLM can be installed from nightly:
+
+```bash
+pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
+```
+
+and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
+
+```bash
+vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
+```
+
+#### SGLang
+
 SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for information on installing
 SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:
 
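Once the server from the hunk above is up, it exposes vLLM's OpenAI-compatible API. A quick smoke test, not part of this PR; host, port, and model name simply follow the serve command above:

```python
import requests

# Matches the serve command above: port 8888, served-model-name axolotl/gpt-oss-20b.
resp = requests.post(
    "http://localhost:8888/v1/chat/completions",
    json={
        "model": "axolotl/gpt-oss-20b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```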
@@ -44,7 +44,7 @@ bf16: true
 tf32: true
 
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 
 gradient_checkpointing: true
 activation_offloading: true
@@ -40,7 +40,7 @@ bf16: true
 tf32: true
 
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 
 gradient_checkpointing: true
 activation_offloading: true
@@ -15,7 +15,7 @@ datasets:
     field_thinking: thinking
     template_thinking_key: thinking
 
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path: ./outputs/last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
 
@@ -41,7 +41,7 @@ bf16: true
 tf32: true
 
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 
 gradient_checkpointing: true
 activation_offloading: true
@@ -15,7 +15,7 @@ datasets:
     field_thinking: thinking
     template_thinking_key: thinking
 
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path: ./outputs/last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
 
@@ -40,7 +40,7 @@ bf16: true
 tf32: true
 
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 
 gradient_checkpointing: true
 activation_offloading: true
@@ -53,7 +53,7 @@ bf16: true
 tf32: true
 
 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
 
 gradient_checkpointing: true
 activation_offloading: true
@@ -13,8 +13,8 @@ liger-kernel==0.6.1
 packaging==23.2
 
 huggingface_hub>=0.33.0
-peft==0.17.0
-transformers==4.55.2
+peft>=0.17.0
+transformers==4.55.3
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0
setup.py (4 changed lines)
@@ -118,9 +118,9 @@ def get_package_version():
 
 
 extras_require = {
-    "flash-attn": ["flash-attn==2.8.2"],
+    "flash-attn": ["flash-attn==2.8.3"],
     "ring-flash-attn": [
-        "flash-attn==2.8.2",
+        "flash-attn==2.8.3",
        "ring-flash-attn>=0.1.7",
        "yunchang==0.6.0",
    ],
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
         return res
 
     def get_image(self):
-        docker_tag = "main-py3.11-cu124-2.6.0"
+        docker_tag = "main-py3.11-cu126-2.7.1"
         if self.config.docker_tag:
             docker_tag = self.config.docker_tag
         docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
         if family in ["a10", "a10g"]:
             return modal.gpu.A10G(count=count)
         if family == "h100":
-            return modal.gpu.H100(count=count)
+            return f"H100:{count}"
         if family == "t4":
             return modal.gpu.T4(count=count)
         if family == "l4":
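For context on the H100 change above: newer Modal releases accept GPU requests as plain strings rather than `modal.gpu` objects. A minimal sketch of the string form, illustrative only and not code from this PR:

```python
import modal

app = modal.App("axolotl-example")  # hypothetical app name

# String GPU spec, mirroring the f"H100:{count}" returned in the diff above.
@app.function(gpu="H100:8")
def train():
    ...
```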
@@ -64,7 +64,7 @@ def do_inference(
             importlib.import_module("axolotl.prompters"), prompter
         )
     elif cfg.chat_template:
-        chat_template_str = get_chat_template(cfg.chat_template)
+        chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
     elif cfg.datasets[0].type == "chat_template":
         chat_template_str = get_chat_template_from_config(
             cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
@@ -97,7 +97,8 @@ def do_cli(
     """
     # pylint: disable=duplicate-code
     os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
-    parsed_cfg = load_cfg(config, **kwargs)
+    is_preprocess = kwargs.pop("is_preprocess", True)
+    parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
     parsed_cfg.is_preprocess = True
     parser = transformers.HfArgumentParser(PreprocessCliArgs)
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -3,11 +3,12 @@
 import random
 from copy import deepcopy
 from itertools import product
+from typing import Any
 
 
 def generate_sweep_configs(
     base_config: dict[str, list], sweeps_config: dict[str, list]
-) -> list[dict[str, list]]:
+) -> list[dict[str, Any]]:
     """
     Recursively generates all possible configurations by applying sweeps to the base config.
 
@@ -4,6 +4,7 @@ import os
 import subprocess  # nosec
 import sys
 import tempfile
+from pathlib import Path
 from typing import Any, Iterator, Literal
 
 import yaml
@@ -88,7 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
     # Generate all possible configurations
     permutations = generate_sweep_configs(base_config, sweep_config)
     is_group = len(permutations) > 1
-    for permutation in permutations:
+    base_output_dir = base_config.get("output_dir", "./model-out")
+    for idx, permutation in enumerate(permutations, start=1):
+        permutation_dir = Path(permutation.get("output_dir", base_output_dir))
+        permutation_id = f"sweep{idx:04d}"
+        permutation["output_dir"] = str(permutation_dir / permutation_id)
+
         # pylint: disable=consider-using-with
         temp_file = tempfile.NamedTemporaryFile(
             mode="w",
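The net effect of the sweep hunk above is that each permutation writes to its own numbered subdirectory instead of all permutations sharing one `output_dir`. A tiny sketch of the naming scheme, mirroring the `f"sweep{idx:04d}"` logic in the diff:

```python
from pathlib import Path

base_output_dir = "./model-out"  # the fallback used in the diff above
for idx in range(1, 4):
    # Each sweep permutation lands in <output_dir>/sweepNNNN
    print(Path(base_output_dir) / f"sweep{idx:04d}")
# model-out/sweep0001
# model-out/sweep0002
# model-out/sweep0003
```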
@@ -6,7 +6,6 @@ from dataclasses import dataclass
 
 from datasets import Dataset
 
-import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
@@ -40,7 +40,6 @@ from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
     MambaDataCollator,
-    StreamingDataCollator,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
@@ -423,17 +422,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         is_eval=False,
         **kwargs,
     ):
-        from datasets import IterableDataset
-
-        if isinstance(self.train_dataset, IterableDataset) and not is_eval:
-            LOG.info("Using StreamingDataCollator")
-            return StreamingDataCollator(
-                tokenizer=self.tokenizer,
-                cfg=self.cfg,
-                prompter=None,
-                **kwargs,
-            )
-
         if training_args.pretraining:
             if (
                 self.cfg.pretraining_sample_concatenation is False
@@ -488,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
         ):
             collator = V2BatchSamplerDataCollatorForSeq2Seq
+            if self.cfg.squash_position_ids:
+                kwargs["squash_position_ids"] = True
         else:
             collator = BatchSamplerDataCollatorForSeq2Seq
     else:
@@ -43,11 +43,7 @@ class TokenizedPromptDataset(Dataset):
         )
 
     def process(self, dataset):
-        # For IterableDataset, we can't access features upfront
-        # We'll need to infer from the first batch
-        features = None
-        if hasattr(dataset, "features") and dataset.features:
-            features = dataset.features.keys()
+        features = dataset.features.keys()
 
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
@@ -58,29 +54,18 @@ class TokenizedPromptDataset(Dataset):
             hasattr(self.prompt_tokenizer, "filter_rows")
             and self.prompt_tokenizer.filter_rows
         ):
-            filter_kwargs = {"desc": "Strategy Filtering Rows"}
-            # Only add num_proc for regular datasets
-            if features is not None:
-                filter_kwargs["num_proc"] = self.process_count
-
             dataset = dataset.filter(
                 self.prompt_tokenizer.filter_rows,
-                **filter_kwargs,
+                num_proc=self.process_count,
+                desc="Strategy Filtering Rows",
             )
 
-        map_kwargs = {
-            **map_kwargs,
-            "desc": "Tokenizing Prompts",
-        }
-
-        # Only add remove_columns for regular datasets
-        if features is not None:
-            map_kwargs["remove_columns"] = features
-            map_kwargs["num_proc"] = self.process_count
-            map_kwargs["keep_in_memory"] = self.keep_in_memory
-
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
+            num_proc=self.process_count,
+            remove_columns=features,
+            keep_in_memory=self.keep_in_memory,
+            desc="Tokenizing Prompts",
             **map_kwargs,
         )
 
@@ -277,6 +277,14 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
+        if self.cfg.sample_packing:
+            from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+                apply_multipack_dataloader_patch,
+            )
+
+            LOG.info("Applying multipack dataloader patch for sample packing...")
+            apply_multipack_dataloader_patch()
+
     def _apply_fsdp2_bnb_patches(self):
         """Apply FSDP2 BNB patches."""
         if (
@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
 
     # Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
     # wrap this. Therefore we must ensure the bias has the same dtype as the weight
-    if module.base_layer.bias is not None:
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
         if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
             log_bias_dtype_mismatch = True
             module.base_layer.bias.data = module.base_layer.bias.data.to(
@@ -1,4 +1,4 @@
-"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
+"""Monkey patches for the dataset fetcher to handle batches of packed indexes."""
 
 # pylint: disable=protected-access
 
@@ -6,10 +6,20 @@ import torch
 from torch.utils.data._utils.fetch import _BaseDatasetFetcher
 from torch.utils.data._utils.worker import _worker_loop
 
+_ORIGINAL_MAP_DATASET_FETCHER = None
+_ORIGINAL_WORKER_LOOP = None
+_IS_PATCHED = False
+
 
 class _MapDatasetFetcher(_BaseDatasetFetcher):
+    """
+    Custom dataset fetcher that handles nested batch structures from
+    MultipackBatchSampler.
+    """
+
     def fetch(self, possibly_batched_index):
         if isinstance(possibly_batched_index[0], list):
+            # Handle nested structure from MultipackBatchSampler
             data = [None for i in possibly_batched_index]
             for i, possibly_batched_index_ in enumerate(possibly_batched_index):
                 if self.auto_collation:
@@ -23,6 +33,7 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
                 else:
                     data[i] = self.dataset[possibly_batched_index_]
         else:
+            # Standard batch handling
             if self.auto_collation:
                 if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                     data = self.dataset.__getitems__(possibly_batched_index)
@@ -34,14 +45,54 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
 
 
 def patch_fetchers():
+    """Apply patches to PyTorch's DataLoader components."""
     torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
     torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
 
 
 def patched_worker_loop(*args, **kwargs):
+    """Worker loop that ensures patches are applied in worker processes."""
     patch_fetchers()
     return _worker_loop(*args, **kwargs)
 
 
-torch.utils.data._utils.worker._worker_loop = patched_worker_loop
-patch_fetchers()
+def apply_multipack_dataloader_patch():
+    """
+    This patch allows DataLoader to correctly process batches that contain multiple bins
+    of packed sequences.
+    """
+    # pylint: disable=global-statement
+    global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED
+
+    if _IS_PATCHED:
+        return
+
+    # Store original implementations
+    _ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher
+    _ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop
+
+    # Apply patches
+    patch_fetchers()
+    torch.utils.data._utils.worker._worker_loop = patched_worker_loop
+
+    _IS_PATCHED = True
+
+
+def remove_multipack_dataloader_patch():
+    """Remove the monkeypatch and restore original PyTorch DataLoader behavior."""
+    # pylint: disable=global-statement
+    global _IS_PATCHED
+
+    if not _IS_PATCHED:
+        return
+
+    if _ORIGINAL_MAP_DATASET_FETCHER:
+        torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER
+        torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (
+            _ORIGINAL_MAP_DATASET_FETCHER
+        )
+
+    if _ORIGINAL_WORKER_LOOP:
+        torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
+
+    _IS_PATCHED = False
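The new apply/remove pair makes the fetcher patch explicit and reversible instead of firing at import time. A usage sketch, mirroring how the updated test further below drives it:

```python
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
    apply_multipack_dataloader_patch,
    remove_multipack_dataloader_patch,
)

# Swap in the fetcher that understands nested packed-index batches.
apply_multipack_dataloader_patch()
try:
    # ... build a DataLoader with MultipackBatchSampler and iterate ...
    pass
finally:
    # Restore PyTorch's original fetcher and worker loop.
    remove_multipack_dataloader_patch()
```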
@@ -253,7 +253,9 @@ def save_trained_model(
         # final model weights have already been saved by `ReLoRACallback.on_train_end`
         return
 
-    if trainer.is_fsdp_enabled or cfg.fsdp_config:
+    if (  # pylint: disable=too-many-nested-blocks
+        trainer.is_fsdp_enabled or cfg.fsdp_config
+    ):
         if cfg.fsdp_config or cfg.fsdp:
             if cfg.fsdp_config.final_state_dict_type:
                 state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -285,6 +287,8 @@ def save_trained_model(
     if trainer.accelerator.is_main_process:
         # move all files in merged_path to cfg.output_dir
         for merged_file in Path(merged_path).iterdir():
+            if (Path(cfg.output_dir) / merged_file.name).exists():
+                (Path(cfg.output_dir) / merged_file.name).unlink()
             shutil.move(str(merged_file), cfg.output_dir)
         shutil.rmtree(merged_path)  # remove what should be an empty dir
         # TODO(wing):see https://github.com/huggingface/transformers/pull/40207
@@ -1,19 +1,11 @@
-"""Shared axolotl collators for multipack, mamba, multimodal, etc."""
+"""
+shared axolotl collators for multipack, mamba, multimodal
+"""
 
-from .batching import (
+from .batching import (  # noqa: F401
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
     PretrainingBatchSamplerDataCollatorForSeq2Seq,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
-from .mamba import MambaDataCollator
-from .streaming import StreamingDataCollator
-
-__all__ = [
-    "BatchSamplerDataCollatorForSeq2Seq",
-    "DataCollatorForSeq2Seq",
-    "PretrainingBatchSamplerDataCollatorForSeq2Seq",
-    "V2BatchSamplerDataCollatorForSeq2Seq",
-    "MambaDataCollator",
-    "StreamingDataCollator",
-]
+from .mamba import MambaDataCollator  # noqa: F401
@@ -1,146 +0,0 @@
-from dataclasses import dataclass
-from typing import Any, List
-
-import torch
-from transformers import PreTrainedTokenizerBase, default_data_collator
-from transformers.utils import PaddingStrategy
-
-from axolotl.prompters import Prompter
-from axolotl.utils.dict import DictDefault
-
-
-@dataclass
-class StreamingDataCollator:
-    tokenizer: PreTrainedTokenizerBase
-    cfg: DictDefault
-    prompter: Prompter | None = None
-    padding: bool | str | PaddingStrategy = True
-    max_length: int | None = None
-    pad_to_multiple_of: int | None = None
-    label_pad_token_id: int = -100
-    return_tensors: str = "pt"
-
-    def __post_init__(self):
-        if self.max_length is None:
-            self.max_length = self.cfg.sequence_len
-
-    def __call__(self, raw_batch: List[dict]) -> dict[str, Any]:
-        processed_samples = []
-
-        for raw_sample in raw_batch:
-            formatted_sample = raw_sample
-            if self.prompter:
-                formatted_sample = self._apply_prompt_formatting(raw_sample)
-
-            tokenized_sample = self._tokenize_sample(formatted_sample)
-
-            if len(tokenized_sample["input_ids"]) > self.max_length:
-                tokenized_sample = self._truncate_sample(tokenized_sample)
-
-            if tokenized_sample.get("input_ids"):
-                processed_samples.append(tokenized_sample)
-
-        return self._pad_and_batch(processed_samples)
-
-    def _apply_prompt_formatting(self, raw_sample: dict) -> dict:
-        formatted_text = self.prompter.build_prompt(
-            instruction=raw_sample.get("instruction", ""),
-            input=raw_sample.get("input", ""),
-            output=raw_sample.get("output", ""),
-        )
-        return {"text": formatted_text}
-
-    def _tokenize_sample(self, sample: dict) -> dict:
-        text = sample.get("text", sample.get("content", ""))
-
-        if not text:
-            instruction = sample.get("instruction", "")
-            input_text = sample.get("input", "")
-            output_text = sample.get("output", "")
-
-            parts = []
-            if instruction:
-                parts.append(f"Instruction: {instruction}")
-            if input_text:
-                parts.append(f"Input: {input_text}")
-            if output_text:
-                parts.append(f"Output: {output_text}")
-            text = "\n".join(parts)
-
-        if not text:
-            return {"input_ids": [], "attention_mask": [], "labels": []}
-
-        tokenized = self.tokenizer(
-            text,
-            truncation=False,
-            padding=False,
-            return_tensors=None,
-        )
-
-        tokenized["labels"] = tokenized["input_ids"].copy()
-        return tokenized
-
-    def _truncate_sample(self, tokenized_sample: dict) -> dict:
-        max_len = self.max_length
-        for key in ["input_ids", "attention_mask", "labels"]:
-            if key in tokenized_sample:
-                tokenized_sample[key] = tokenized_sample[key][:max_len]
-        return tokenized_sample
-
-    def _pad_and_batch(self, processed_samples: List[dict]) -> dict[str, Any]:
-        if not processed_samples:
-            processed_samples = [
-                {
-                    "input_ids": [self.tokenizer.eos_token_id],
-                    "attention_mask": [1],
-                    "labels": [self.tokenizer.eos_token_id],
-                }
-            ]
-
-        batch_samples = []
-        for sample in processed_samples:
-            batch_sample = {}
-            for key, value in sample.items():
-                if key in ["input_ids", "attention_mask", "labels"]:
-                    batch_sample[key] = torch.tensor(value, dtype=torch.long)
-            batch_samples.append(batch_sample)
-
-        if self.padding:
-            max_len_in_batch = max(len(sample["input_ids"]) for sample in batch_samples)
-
-            for sample in batch_samples:
-                current_len = len(sample["input_ids"])
-                pad_len = max_len_in_batch - current_len
-
-                if pad_len > 0:
-                    pad_token_id = (
-                        self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
-                    )
-
-                    sample["input_ids"] = torch.cat(
-                        [
-                            sample["input_ids"],
-                            torch.full((pad_len,), pad_token_id, dtype=torch.long),
-                        ]
-                    )
-                    sample["attention_mask"] = torch.cat(
-                        [
-                            sample["attention_mask"],
-                            torch.zeros(pad_len, dtype=torch.long),
-                        ]
-                    )
-                    sample["labels"] = torch.cat(
-                        [
-                            sample["labels"],
-                            torch.full(
-                                (pad_len,), self.label_pad_token_id, dtype=torch.long
-                            ),
-                        ]
-                    )
-
-        batch = {}
-        for key in ["input_ids", "attention_mask", "labels"]:
-            if key in batch_samples[0]:
-                batch[key] = torch.stack([sample[key] for sample in batch_samples])
-
-        return batch
@@ -9,7 +9,6 @@ from datasets import (
     Dataset,
     DatasetDict,
     IterableDataset,
-    IterableDatasetDict,
     load_dataset,
 )
 from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -44,18 +43,6 @@ from axolotl.utils.trainer import (
 LOG = get_logger(__name__)
 
 
-def _determine_streaming_mode(cfg: DictDefault) -> bool:
-    """Determine if we should use streaming mode based on config."""
-    if cfg.streaming is not None:
-        return cfg.streaming
-
-    # Default to streaming for pretraining datasets
-    if cfg.pretraining_dataset:
-        return True
-
-    return False
-
-
 @retry_on_request_exceptions(max_retries=3, delay=5)
 def prepare_datasets(
     cfg: DictDefault,
@@ -74,52 +61,11 @@ def prepare_datasets(
     Returns:
         Tuple of (train_dataset, eval_dataset, total_steps, prompters).
     """
-    streaming_mode = _determine_streaming_mode(cfg)
-
-    if streaming_mode:
-        if cfg.pretraining_dataset:
-            return _prepare_streaming_pretraining_dataset(cfg, tokenizer, processor)
-        else:
-            return _prepare_streaming_sft_dataset(cfg, tokenizer, processor)
-    else:
-        if cfg.pretraining_dataset:
-            return _prepare_pretraining_dataset(
-                cfg, tokenizer, processor, preprocess_iterable=False
-            )
-        else:
-            return _prepare_standard_dataset(
-                cfg, tokenizer, processor, preprocess_iterable=False
-            )
-
-
-def _prepare_streaming_sft_dataset(
-    cfg: DictDefault,
-    tokenizer: PreTrainedTokenizer,
-    processor: ProcessorMixin | None,
-) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
-    LOG.info("Loading streaming datasets")
-
-    raw_datasets = _load_raw_datasets_for_streaming(cfg, split="train")
-
-    eval_dataset = None
-    if cfg.test_datasets:
-        eval_raw_datasets = _load_raw_datasets_for_streaming(
-            cfg, split="test", dataset_configs=cfg.test_datasets
+    if cfg.pretraining_dataset:
+        return _prepare_pretraining_dataset(
+            cfg, tokenizer, processor, preprocess_iterable
         )
-        eval_dataset = _process_eval_dataset_minimal(
-            eval_raw_datasets, cfg, tokenizer, processor
-        )
-    elif cfg.val_set_size:
-        LOG.info("Validation splits not supported for streaming datasets")
-
-    if not cfg.max_steps:
-        raise ValueError("max_steps must be set when using streaming datasets")
-
-    total_num_steps = cfg.max_steps
-    LOG.info(f"Maximum steps: {total_num_steps}")
-
-    prompters = [None] * len(cfg.datasets) if cfg.datasets else []
-    return raw_datasets, eval_dataset, total_num_steps, prompters
+    return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
 
 
 def _prepare_standard_dataset(
@@ -427,7 +373,7 @@ def _load_and_process_single_dataset(
     d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
 
     # Select the appropriate split
-    if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
+    if isinstance(dataset, DatasetDict):
         if dataset_config.split and dataset_config.split in dataset:
             dataset = dataset[dataset_config.split]
         elif split in dataset:
@@ -566,78 +512,3 @@ def _load_and_prepare_datasets(
     train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
 
     return train_dataset, eval_dataset, prompters
-
-
-def _load_raw_datasets_for_streaming(
-    cfg: DictDefault, split: str = "train", dataset_configs: list | None = None
-) -> IterableDataset:
-    configs = (
-        dataset_configs
-        if dataset_configs is not None
-        else (cfg.datasets if split == "train" else cfg.test_datasets)
-    )
-
-    if not configs:
-        raise ValueError(f"No dataset configurations found for split '{split}'")
-
-    datasets = []
-    for dataset_config in datasets_with_name_generator(configs):
-        raw_dataset = load_dataset_with_config(
-            dataset_config, cfg.hf_use_auth_token, streaming=True
-        )
-
-        if isinstance(raw_dataset, (DatasetDict, IterableDatasetDict)):
-            if dataset_config.split and dataset_config.split in raw_dataset:
-                raw_dataset = raw_dataset[dataset_config.split]
-            elif split in raw_dataset:
-                raw_dataset = raw_dataset[split]
-            else:
-                raise ValueError(
-                    f"no {split} split found for dataset {dataset_config.path}, "
-                    "you may specify a split with 'split: ...'"
-                )
-
-        datasets.append(raw_dataset)
-
-    if len(datasets) == 1:
-        return datasets[0]
-    else:
-        return merge_datasets(datasets, cfg)
-
-
-def _process_eval_dataset_minimal(
-    raw_dataset: IterableDataset,
-    cfg: DictDefault,
-    tokenizer: PreTrainedTokenizer,
-    processor: ProcessorMixin | None,
-) -> Dataset | None:
-    LOG.info("Eval dataset processing skipped for streaming")
-    return None
-
-
-def _prepare_streaming_pretraining_dataset(
-    cfg: DictDefault,
-    tokenizer: PreTrainedTokenizer,
-    processor: ProcessorMixin | None,
-) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
-    pretraining_config = _extract_pretraining_config(cfg)
-
-    train_dataset = load_dataset_with_config(
-        pretraining_config, cfg.hf_use_auth_token, streaming=True
-    )
-
-    if isinstance(train_dataset, (DatasetDict, IterableDatasetDict)):
-        if pretraining_config.split and pretraining_config.split in train_dataset:
-            train_dataset = train_dataset[pretraining_config.split]
-        elif "train" in train_dataset:
-            train_dataset = train_dataset["train"]
-        else:
-            raise ValueError("no train split found for pretraining dataset")
-
-    if not cfg.max_steps:
-        raise ValueError("max_steps must be set when using streaming datasets")
-
-    total_num_steps = cfg.max_steps
-    LOG.info(f"Maximum steps: {total_num_steps}")
-
-    return train_dataset, None, total_num_steps, []
@@ -190,18 +190,12 @@ def handle_long_seq_in_dataset(
     Returns:
         Filtered dataset with long sequences removed.
     """
-    if hasattr(dataset, "column_names") and dataset.column_names:
-        if "input_ids" not in dataset.column_names:
-            LOG.warning(
-                "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
-                "expected for reward modeling."
-            )
-            return dataset
-    else:
-        # For IterableDataset, we can't check columns upfront, so skip for streaming
-        if isinstance(dataset, IterableDataset):
-            LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
-            return dataset
+    if "input_ids" not in dataset.column_names:
+        LOG.warning(
+            "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
+            "expected for reward modeling."
+        )
+        return dataset
 
     drop_long = functools.partial(
         drop_long_seq,
@@ -459,6 +459,12 @@ class AxolotlInputConfig(
             "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
         },
     )
+    squash_position_ids: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to squash position_ids for packing, effectively extending context length."
+        },
+    )
     eval_sample_packing: bool | None = Field(
         default=None,
         json_schema_extra={
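A hypothetical axolotl config snippet exercising the new flag (the field name comes from the schema above; the other values are purely illustrative):

```yaml
sequence_len: 4096
sample_packing: true
# squash position_ids within each pack, effectively extending context length
# (per the schema description above)
squash_position_ids: true
```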
@@ -932,34 +938,6 @@ class AxolotlInputConfig(
 
     fix_untrained_tokens: int | list[int] | None = None
 
-    streaming: bool | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
-        },
-    )
-
-    streaming_dataset_mixing_strategy: str | None = Field(
-        default="round_robin",
-        json_schema_extra={
-            "description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
-        },
-    )
-
-    streaming_mixing_weights: list[float] | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
-        },
-    )
-
-    streaming_buffer_per_dataset: int | None = Field(
-        default=1000,
-        json_schema_extra={
-            "description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory."
-        },
-    )
-
     # INTERNALS - document for now, generally not set externally
     is_preprocess: bool | None = None
     preprocess_iterable: bool | None = None
@@ -1337,30 +1337,6 @@ class GRPOVllmValidationMixin:
 
 
 # pylint: disable=too-many-ancestors
-class StreamingValidationMixin:
-    """Validation methods related to streaming datasets."""
-
-    @model_validator(mode="after")
-    def check_streaming_requires_max_steps(self):
-        """Ensure max_steps is set when using streaming datasets."""
-        # Check if streaming is explicitly enabled
-        streaming_enabled = getattr(self, "streaming", None) is True
-
-        # Check if pretraining dataset exists (defaults to streaming)
-        has_pretraining = getattr(self, "pretraining_dataset", None) is not None
-        streaming_default_for_pretraining = (
-            has_pretraining and getattr(self, "streaming", None) is None
-        )
-
-        # If streaming is enabled (explicitly or by default for pretraining)
-        if streaming_enabled or streaming_default_for_pretraining:
-            max_steps = getattr(self, "max_steps", None)
-            if not max_steps:
-                raise ValueError("max_steps must be set when using streaming datasets")
-
-        return self
-
-
 class ValidationMixin(
     DatasetValidationMixin,
     AttentionValidationMixin,
@@ -1371,7 +1347,6 @@ class ValidationMixin(
     SystemValidationMixin,
     ChatTemplateValidationMixin,
     PretrainingValidationMixin,
-    StreamingValidationMixin,
     ModelCompatibilityValidationMixin,
     ComplexValidationMixin,
     GRPOVllmValidationMixin,
@@ -48,7 +48,13 @@ class TestBatchedSamplerPacking:
         max_seq_length,
         sequential,
     ):
-        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401
+        from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+            apply_multipack_dataloader_patch,
+            remove_multipack_dataloader_patch,
+        )
+
+        # Apply the patch for multipack handling
+        apply_multipack_dataloader_patch()
 
         dataset = dataset_winglian_tiny_shakespeare["train"]
 
@@ -101,10 +107,14 @@ class TestBatchedSamplerPacking:
             for pack in batch:
                 batch_idxs.extend(pack)
 
-        for batch in loader:
-            assert batch["input_ids"].numel() <= batch_size * max_seq_length
-            assert batch["input_ids"].shape[1] == max_seq_length
+        try:
+            for batch in loader:
+                assert batch["input_ids"].numel() <= batch_size * max_seq_length
+                assert batch["input_ids"].shape[1] == max_seq_length
 
-        original_idxs = set(range(len(train_dataset)))
-        assert original_idxs == set(batch_idxs)
-        assert len(batch_idxs) == len(set(batch_idxs))
+            original_idxs = set(range(len(train_dataset)))
+            assert original_idxs == set(batch_idxs)
+            assert len(batch_idxs) == len(set(batch_idxs))
+        finally:
+            # Clean up: remove the patch after the test
+            remove_multipack_dataloader_patch()