Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
bb65157dcf fix conditional for None values 2025-08-17 12:49:48 -04:00
Wing Lian
7fd3d8abc4 handle batch size correchtly when using split and dispatch batches 2025-08-16 22:05:31 -04:00
27 changed files with 53 additions and 207 deletions

View File

@@ -12,6 +12,5 @@ reviews:
auto_review: auto_review:
enabled: true enabled: true
drafts: false drafts: false
auto_incremental_review: true
chat: chat:
auto_reply: true auto_reply: true

View File

@@ -41,12 +41,6 @@ model, and final model output, you may need at least 3TB of free disk space to k
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
``` ```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`. ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue. See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
@@ -67,23 +61,9 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
### Inferencing your fine-tuned model ### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425 GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM. for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server: SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:

View File

@@ -44,7 +44,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking field_thinking: thinking
template_thinking_key: thinking template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/gpt-oss-out/ output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking field_thinking: thinking
template_thinking_key: thinking template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/gpt-oss-out/ output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true tf32: true
flash_attention: true flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true gradient_checkpointing: true
activation_offloading: true activation_offloading: true

View File

@@ -13,8 +13,8 @@ liger-kernel==0.6.1
packaging==23.2 packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft>=0.17.0 peft==0.17.0
transformers==4.55.3 transformers==4.55.2
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.10.0 accelerate==1.10.0
datasets==4.0.0 datasets==4.0.0

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = { extras_require = {
"flash-attn": ["flash-attn==2.8.3"], "flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [ "ring-flash-attn": [
"flash-attn==2.8.3", "flash-attn==2.8.2",
"ring-flash-attn>=0.1.7", "ring-flash-attn>=0.1.7",
"yunchang==0.6.0", "yunchang==0.6.0",
], ],

View File

@@ -40,12 +40,6 @@ class VllmServeCliArgs:
default=None, default=None,
metadata={"help": "Number of tensor parallel workers to use."}, metadata={"help": "Number of tensor parallel workers to use."},
) )
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field( host: Optional[str] = field(
default=None, # nosec B104 default=None, # nosec B104
metadata={"help": "Host address to run the server on."}, metadata={"help": "Host address to run the server on."},

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res return res
def get_image(self): def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1" docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag: if self.config.docker_tag:
docker_tag = self.config.docker_tag docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}" docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]: if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count) return modal.gpu.A10G(count=count)
if family == "h100": if family == "h100":
return f"H100:{count}" return modal.gpu.H100(count=count)
if family == "t4": if family == "t4":
return modal.gpu.T4(count=count) return modal.gpu.T4(count=count)
if family == "l4": if family == "l4":

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template": elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config( chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -97,8 +97,7 @@ def do_cli(
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1" os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True) parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg.is_preprocess = True parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs) parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,12 +3,11 @@
import random import random
from copy import deepcopy from copy import deepcopy
from itertools import product from itertools import product
from typing import Any
def generate_sweep_configs( def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list] base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, Any]]: ) -> list[dict[str, list]]:
""" """
Recursively generates all possible configurations by applying sweeps to the base config. Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -4,7 +4,6 @@ import os
import subprocess # nosec import subprocess # nosec
import sys import sys
import tempfile import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal from typing import Any, Iterator, Literal
import yaml import yaml
@@ -89,12 +88,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations # Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config) permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1 is_group = len(permutations) > 1
base_output_dir = base_config.get("output_dir", "./model-out") for permutation in permutations:
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
# pylint: disable=consider-using-with # pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile( temp_file = tempfile.NamedTemporaryFile(
mode="w", mode="w",

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from datasets import Dataset from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.data import prepare_datasets, prepare_preference_datasets

View File

@@ -424,7 +424,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
): ):
if training_args.pretraining: if training_args.pretraining:
if ( if (
self.cfg.pretraining_sample_concatenation is False not self.cfg.pretraining_sample_concatenation
or self.cfg.micro_batch_size > 1 or self.cfg.micro_batch_size > 1
): ):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
@@ -476,8 +476,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
): ):
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
if self.cfg.squash_position_ids:
kwargs["squash_position_ids"] = True
else: else:
collator = BatchSamplerDataCollatorForSeq2Seq collator = BatchSamplerDataCollatorForSeq2Seq
else: else:

View File

@@ -272,6 +272,20 @@ class AxolotlTrainer(
num_workers=self.args.dataloader_num_workers, num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index, rank=self.args.process_index,
) )
if (
self.args.accelerator_config is not None
and self.args.accelerator_config.split_batches
and self.args.accelerator_config.dispatch_batches
):
if self.args.sample_packing and self.args.pretraining:
if not self.args.eval_sample_packing and not is_training:
dataloader_params["batch_size"] *= self.accelerator.num_processes
else:
dataloader_params["batch_size"] = self.accelerator.num_processes
elif not self.args.sample_packing and self.args.pretraining:
dataloader_params["batch_size"] *= self.accelerator.num_processes
if self.args.sample_packing and ( if self.args.sample_packing and (
(is_training and not self.args.pretraining) (is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False) or (not is_training and self.args.eval_sample_packing is not False)

View File

View File

@@ -277,14 +277,6 @@ class PatchManager:
has_remote_code=has_remote_code, has_remote_code=has_remote_code,
) )
if self.cfg.sample_packing:
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
apply_multipack_dataloader_patch,
)
LOG.info("Applying multipack dataloader patch for sample packing...")
apply_multipack_dataloader_patch()
def _apply_fsdp2_bnb_patches(self): def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches.""" """Apply FSDP2 BNB patches."""
if ( if (

View File

@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight # wrap this. Therefore we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None: if module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype: if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to( module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -1,4 +1,4 @@
"""Monkey patches for the dataset fetcher to handle batches of packed indexes.""" """monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access # pylint: disable=protected-access
@@ -6,20 +6,10 @@ import torch
from torch.utils.data._utils.fetch import _BaseDatasetFetcher from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data._utils.worker import _worker_loop from torch.utils.data._utils.worker import _worker_loop
_ORIGINAL_MAP_DATASET_FETCHER = None
_ORIGINAL_WORKER_LOOP = None
_IS_PATCHED = False
class _MapDatasetFetcher(_BaseDatasetFetcher): class _MapDatasetFetcher(_BaseDatasetFetcher):
"""
Custom dataset fetcher that handles nested batch structures from
MultipackBatchSampler.
"""
def fetch(self, possibly_batched_index): def fetch(self, possibly_batched_index):
if isinstance(possibly_batched_index[0], list): if isinstance(possibly_batched_index[0], list):
# Handle nested structure from MultipackBatchSampler
data = [None for i in possibly_batched_index] data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index): for i, possibly_batched_index_ in enumerate(possibly_batched_index):
if self.auto_collation: if self.auto_collation:
@@ -33,7 +23,6 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
else: else:
data[i] = self.dataset[possibly_batched_index_] data[i] = self.dataset[possibly_batched_index_]
else: else:
# Standard batch handling
if self.auto_collation: if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index) data = self.dataset.__getitems__(possibly_batched_index)
@@ -45,54 +34,14 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
def patch_fetchers(): def patch_fetchers():
"""Apply patches to PyTorch's DataLoader components."""
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
def patched_worker_loop(*args, **kwargs): def patched_worker_loop(*args, **kwargs):
"""Worker loop that ensures patches are applied in worker processes."""
patch_fetchers() patch_fetchers()
return _worker_loop(*args, **kwargs) return _worker_loop(*args, **kwargs)
def apply_multipack_dataloader_patch():
"""
This patch allows DataLoader to correctly process batches that contain multiple bins
of packed sequences.
"""
# pylint: disable=global-statement
global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED
if _IS_PATCHED:
return
# Store original implementations
_ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher
_ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop
# Apply patches
patch_fetchers()
torch.utils.data._utils.worker._worker_loop = patched_worker_loop torch.utils.data._utils.worker._worker_loop = patched_worker_loop
patch_fetchers()
_IS_PATCHED = True
def remove_multipack_dataloader_patch():
"""Remove the monkeypatch and restore original PyTorch DataLoader behavior."""
# pylint: disable=global-statement
global _IS_PATCHED
if not _IS_PATCHED:
return
if _ORIGINAL_MAP_DATASET_FETCHER:
torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (
_ORIGINAL_MAP_DATASET_FETCHER
)
if _ORIGINAL_WORKER_LOOP:
torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
_IS_PATCHED = False

View File

@@ -253,9 +253,7 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return return
if ( # pylint: disable=too-many-nested-blocks if trainer.is_fsdp_enabled or cfg.fsdp_config:
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if cfg.fsdp_config or cfg.fsdp: if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type: if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -287,8 +285,6 @@ def save_trained_model(
if trainer.accelerator.is_main_process: if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir # move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir(): for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir) shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207 # TODO(wing):see https://github.com/huggingface/transformers/pull/40207

View File

@@ -28,7 +28,7 @@ from axolotl.utils.data.shared import (
) )
from axolotl.utils.data.utils import ( from axolotl.utils.data.utils import (
deduplicate_and_log_datasets, deduplicate_and_log_datasets,
handle_long_seq_in_dataset, drop_long_seq_in_dataset,
retry_on_request_exceptions, retry_on_request_exceptions,
) )
from axolotl.utils.data.wrappers import get_dataset_wrapper from axolotl.utils.data.wrappers import get_dataset_wrapper
@@ -339,9 +339,9 @@ def _load_raw_datasets(
if not cfg.skip_prepare_dataset: if not cfg.skip_prepare_dataset:
if split == "test" and cfg.eval_sequence_len: if split == "test" and cfg.eval_sequence_len:
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg) dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else: else:
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg) dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing: if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None) dataset, _ = process_datasets_for_packing(cfg, dataset, None)

View File

@@ -148,36 +148,7 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset return dataset, other_dataset
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): def drop_long_seq_in_dataset(
"""
Truncate samples whose sequence length is too long (> sequence_len)
or drop those too short (< min_sequence_len).
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
results = []
# Batched (input_ids is a list of lists)
for i, seq in enumerate(input_ids):
length = len(seq)
if length < min_sequence_len:
results.append(False)
elif length > sequence_len:
sample["input_ids"][i] = seq[:sequence_len]
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
results.append(True)
else:
results.append(True)
return results
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset: ) -> Dataset:
"""Remove sequences longer than configured maximum from dataset. """Remove sequences longer than configured maximum from dataset.
@@ -221,21 +192,8 @@ def handle_long_seq_in_dataset(
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})" drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
else:
process_fn = drop_long
dataset = dataset.filter( dataset = dataset.filter(
process_fn, drop_long,
batched=True, batched=True,
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
@@ -243,11 +201,6 @@ def handle_long_seq_in_dataset(
if prior_len: if prior_len:
dropped = prior_len - len(dataset) dropped = prior_len - len(dataset)
if dropped: if dropped:
action = ( LOG.warning(f"Dropped {dropped} long samples from dataset")
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
return dataset return dataset

View File

@@ -414,12 +414,6 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048" "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
}, },
) )
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field( eval_sequence_len: int | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -459,12 +453,6 @@ class AxolotlInputConfig(
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'" "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
}, },
) )
squash_position_ids: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to squash position_ids for packing, effectively extending context length."
},
)
eval_sample_packing: bool | None = Field( eval_sample_packing: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import handle_long_seq_in_dataset from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -48,13 +48,7 @@ class TestBatchedSamplerPacking:
max_seq_length, max_seq_length,
sequential, sequential,
): ):
from axolotl.monkeypatch.data.batch_dataset_fetcher import ( import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
apply_multipack_dataloader_patch,
remove_multipack_dataloader_patch,
)
# Apply the patch for multipack handling
apply_multipack_dataloader_patch()
dataset = dataset_winglian_tiny_shakespeare["train"] dataset = dataset_winglian_tiny_shakespeare["train"]
@@ -76,7 +70,7 @@ class TestBatchedSamplerPacking:
) )
train_dataset = concatenate_datasets([dataset_wrapper]) train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg) train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
lengths = get_dataset_lengths(train_dataset) lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler( batch_sampler = MultipackBatchSampler(
@@ -107,7 +101,6 @@ class TestBatchedSamplerPacking:
for pack in batch: for pack in batch:
batch_idxs.extend(pack) batch_idxs.extend(pack)
try:
for batch in loader: for batch in loader:
assert batch["input_ids"].numel() <= batch_size * max_seq_length assert batch["input_ids"].numel() <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length assert batch["input_ids"].shape[1] == max_seq_length
@@ -115,6 +108,3 @@ class TestBatchedSamplerPacking:
original_idxs = set(range(len(train_dataset))) original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs) assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs)) assert len(batch_idxs) == len(set(batch_idxs))
finally:
# Clean up: remove the patch after the test
remove_multipack_dataloader_patch()