Compare commits

...

31 Commits

Author SHA1 Message Date
mhenrhcsen
0f2d196476 Remove deprecated configuration files: deleted config.qmd and finetune copy.yml to streamline project structure and eliminate unused resources. 2025-08-12 21:23:34 +02:00
mhenrhcsen
f1a8474400 Remove transscribe.py file and clean up optimizer.py and rl.py for improved formatting and consistency. 2025-08-12 21:20:48 +02:00
mhenrhcsen
dc5887c652 pre-commit: fix rl.py imports/types; add legacy drop_long_rl_seq wrapper; resolve config schema; run formatting 2025-08-12 21:12:07 +02:00
mhenrhcsen
54b542d312 remove unused files 2025-08-12 21:09:40 +02:00
mhenrhcsen
30a89b07b9 Refactor AxolotlInputConfig: clean up sequence_len and sequence_len_overflow_handling fields, ensuring consistent descriptions and removing conflict markers. 2025-08-12 21:03:28 +02:00
mhenrhcsen
746c03b097 Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling 2025-08-12 20:53:28 +02:00
mhenrhcsen
47b3fe8af3 Resolve merge conflicts: unify pretraining utils imports, add alias handling; fix rl.py per new RL dataset API; resolve config schema conflict and add sequence_len_overflow_handling field 2025-08-12 20:45:26 +02:00
mhenrhcsen
f5a3e3529e RL datasets: warn and drop unsalvageable over-length prompts post-truncate; add post-truncate filter; support alias config key 'excess_token_handling' 2025-08-12 20:37:41 +02:00
Wing Lian
3d45620008 remove prepare-from-posids patch (#3052) [skip ci] 2025-08-11 09:34:41 -04:00
github-actions[bot]
ce20e838b5 chore: update pre-commit hooks (#3050) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-08-11 09:32:21 -04:00
Wing Lian
d4d84d48af fix ray train and add fsdp2 smoke test for ray trainer (#3053)
* add fsdp2 smokle test for ray trainer

* fix raytrain with fsdp2
2025-08-11 09:31:54 -04:00
Wing Lian
9b12c05660 use exec instead of subprocess to make ctrl+c nicer for cli (#3044)
* use exec instead of subprocess to make ctrl+c nicer for cli

* change var name to use_exec

* simplify to bool

* flush std*

* patch subprocess as mock in test

* fix tests

* more test fixes
2025-08-10 20:22:20 -04:00
Wing Lian
686933194e fix vllm tagging and add cloud images w/o tmux (#3049) [skip ci] 2025-08-10 20:21:56 -04:00
Wing Lian
d12b461d19 follow up fix for plugin registration (#3054) [skip ci] 2025-08-10 20:21:38 -04:00
Wing Lian
d6b81b3683 update training args check for new defaults (#3051) [skip ci]
* update training args check for new defaults

* skip check for now
2025-08-10 11:26:22 -04:00
Wing Lian
05f1b4b2e8 run monkeypatch tests in seperate runner (#3047) 2025-08-09 14:34:07 -04:00
Wing Lian
7cfc80ec77 set dev version (#3045) [skip ci] 2025-08-08 13:56:53 -04:00
salman
0da6a95efa Add citation.tff (#3043) [skip ci] 2025-08-08 16:18:42 +01:00
mhenrichsen
618b008e36 Merge branch 'main' into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length 2025-05-27 12:31:31 +02:00
mhenrhcsen
5d7a61576d Refactor sequence length overflow handling in pretraining module
- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py.
- Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
2025-05-15 12:55:09 +02:00
mhenrhcsen
5ecf22b54e Merge branch 'main' of github.com:axolotl-ai-cloud/axolotl into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length 2025-05-14 13:36:43 +02:00
mhenrhcsen
9c5b8da22f fix merge conflicts 2025-05-14 13:33:42 +02:00
mhenrhcsen
fea6649518 increased test coverage 2025-05-13 08:58:34 +02:00
mhenrhcsen
124ad2b968 lint 2025-05-13 08:35:16 +02:00
mhenrhcsen
767c2340f1 docstring for tests 2025-05-12 22:57:43 +02:00
mhenrhcsen
f6623c34cc Linting fix 2025-05-12 22:53:30 +02:00
mhenrhcsen
5dd8f0b2b8 Fixes comments from winglian 2025-05-12 22:43:15 +02:00
mhenrhcsen
be3c6bbd85 fix linting issues 2025-05-12 14:46:57 +02:00
mhenrhcsen
f07db4f853 Refactor truncation logic in drop_long_rl_seq function
- Simplified the truncation process for chosen and rejected responses to ensure they fit within the specified sequence length while preserving the prompt.
- Improved readability by restructuring the code and removing redundant checks.
- Ensured that the function returns the sample correctly after processing, maintaining compatibility with existing handling options.
2025-05-12 14:40:10 +02:00
mhenrhcsen
17a5838d38 lint 2025-05-12 14:36:43 +02:00
mhenrhcsen
9f68918f13 Implement configurable handling of excess tokens in datasets
- Added `excess_token_handling` option to the configuration, allowing users to choose between "drop" and "truncate" for handling tokens exceeding the maximum sequence length.
- Introduced `truncate_or_drop_long_seq` function to manage both single and batched samples based on the selected handling method.
- Updated relevant dataset processing functions to utilize the new handling option, ensuring backward compatibility with existing "drop" behavior.
- Enhanced logging to reflect truncation actions in dataset processing.

This change improves flexibility in managing sequence lengths during training and evaluation.
2025-05-12 14:08:43 +02:00
24 changed files with 920 additions and 158 deletions

View File

@@ -98,6 +98,12 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
@@ -151,6 +157,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -105,7 +105,8 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
@@ -179,8 +180,8 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.7
rev: v3.3.8
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

10
CITATION.cff Normal file
View File

@@ -0,0 +1,10 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
url: "https://axolotl.ai/"
license: Apache-2.0
date-released: "2023-05-30"

View File

@@ -149,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📝 Citing Axolotl
If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
year = {2023}
}
```
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.12.0"
__version__ = "0.13.0.dev"

View File

@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# now that we have the finalized cfg, register the plugins individually
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def load_cfg(

View File

@@ -123,9 +123,10 @@ def train(
_launcher = None if kwargs.get("use_ray") else launcher
# Process each configuration
for cfg_file in generate_config_files(config, sweep):
for cfg_file, is_group in generate_config_files(config, sweep):
try:
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:

View File

@@ -2,6 +2,7 @@
import os
import subprocess # nosec
import sys
import tempfile
from typing import Any, Iterator, Literal
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
return cmd
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
"""Generate list of configuration files to process."""
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
"""
Generate list of configuration files to process.
Args:
config: Base configuration file
sweep: Sweep configuration file
Yields:
Tuple of configuration file name and whether this is a group of configurations
"""
if not sweep:
yield config
yield config, False
return
# Load sweep and base configurations
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
)
yaml.dump(permutation, temp_file)
temp_file.close()
yield temp_file.name
yield temp_file.name, is_group
def launch_training(
@@ -97,6 +109,7 @@ def launch_training(
cloud: str | None,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training with the given configuration."""
launcher_args = launcher_args or []
@@ -105,11 +118,14 @@ def launch_training(
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
elif launcher:
if launcher == "accelerate":
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "torchrun":
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "python":
_launch_python_training(cfg_file, kwargs)
elif launcher is None:
# handle ray train launch
_launch_python_training(cfg_file, kwargs)
def _launch_cloud_training(
@@ -136,7 +152,10 @@ def _launch_cloud_training(
def _launch_accelerate_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via accelerate launcher."""
launcher_args = launcher_args or []
@@ -161,11 +180,20 @@ def _launch_accelerate_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_torchrun_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via torchrun launcher."""
launcher_args = launcher_args or []
@@ -178,7 +206,13 @@ def _launch_torchrun_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:

View File

@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init

View File

@@ -76,8 +76,8 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
def register(self, cfg: dict): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration as an unparsed dict.
Args:
cfg: The configuration for the plugin.

View File

@@ -73,9 +73,6 @@ class PatchManager:
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
@@ -87,7 +84,6 @@ class PatchManager:
and self.cfg.fsdp_version == 2
)
patch_prepare_from_posids()
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()

View File

@@ -1,87 +0,0 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.
see https://github.com/huggingface/transformers/pull/39653/files
"""
import sys
import torch
def _prepare_from_posids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
def patch_prepare_from_posids():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
_prepare_from_posids
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_prepare_from_posids",
_prepare_from_posids,
)

View File

@@ -10,6 +10,7 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
@@ -259,6 +260,15 @@ def encode_packed_pretraining(
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
# pass through handling mode from config via ds_wrapper function
handling=(
getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling",
getattr(ds_wrapper, "cfg", {}).get(
"excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
),
)
),
)
sampler = MultipackBatchSampler(

View File

@@ -122,6 +122,14 @@ def _map_dataset(
return dataset
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
"""
Backward-compatibility wrapper for legacy imports in tests.
Delegates to the new predicate.
"""
return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
@@ -155,11 +163,51 @@ def _drop_long_sequences(
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
# Truncate first, then drop if still invalid (although truncate should handle it)
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
if handling_mode == "truncate":
# If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len:
result = sample
else:
# Calculate maximum response length that can fit with the prompt
max_response_len = sequence_len - len_prompt
if rl is RLType.KTO:
if max_response_len <= 0:
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
LOG.warning(
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping",
len_prompt,
sequence_len,
)
result = False
else:
# Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len:
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["chosen"] = tokenizer.decode(
chosen_tokens, skip_special_tokens=True
)
if len_rejected > max_response_len:
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["rejected"] = tokenizer.decode(
rejected_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
elif rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -171,12 +219,86 @@ def _drop_long_sequences(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
# Truncate first
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
if handling_mode == "truncate":
# If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len:
result = sample
else:
# Calculate maximum completion length
max_completion_len = sequence_len - len_prompt
if rl is RLType.GRPO:
return True
if max_completion_len <= 0:
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
LOG.warning(
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping",
len_prompt,
sequence_len,
)
result = False
else:
# Truncate the completion if needed
if len_completion > max_completion_len:
completion_tokens = tokenizer(
completion, add_special_tokens=False
)["input_ids"][:max_completion_len]
sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_completion) <= sequence_len
raise ValueError("Unknown RL type")
elif rl == RLType.GRPO:
# For GRPO always keep
result = True
else:
raise ValueError("Unknown RL type")
return bool(result)
def load_prepare_preference_datasets(cfg):
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
"""
Boolean predicate to check whether a preference-learning sample fits within sequence_len.
Used with dataset.filter() after truncation to drop unsalvageable samples.
"""
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
return False
prompt = sample["prompt"]
chosen = sample["chosen"]
rejected = sample["rejected"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(
tokenizer(rejected, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
if rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
return False
prompt = sample["prompt"]
completion = sample["completion"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_completion = len(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
if rl == RLType.GRPO:
# GRPO does not enforce this check here
return True
return False
# Legacy shim preserved for backward compatibility; no-op in new flow
def load_split(dataset_cfgs, _cfg): # noqa: F811
return None
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:

View File

@@ -15,10 +15,12 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = get_logger(__name__)
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
class RetryStrategy(Enum):
"""Enum for retry strategies."""
@@ -168,10 +170,19 @@ def drop_long_seq_in_dataset(
)
return dataset
drop_long = functools.partial(
drop_long_seq,
# Get the handling method from config, default to "drop" for backward compatibility.
# Support legacy alias "excess_token_handling" as well.
handling = cfg.get(
"sequence_len_overflow_handling",
cfg.get("excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING),
)
# Use the function with the specified handling mode
seq_handler = functools.partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
handling=handling,
)
with contextlib.suppress(AttributeError):
@@ -190,17 +201,31 @@ def drop_long_seq_in_dataset(
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
if handling == "truncate":
drop_long_kwargs["desc"] = "Truncating Long Sequences"
else: # handling == "drop"
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
if handling == "truncate":
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {sequence_len} tokens")
else: # handling == "drop"
# Use filter for drop mode
dataset = dataset.filter(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset

View File

@@ -414,6 +414,12 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
default="drop",
json_schema_extra={
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
},
)
eval_sequence_len: int | None = Field(
default=None,
json_schema_extra={

View File

@@ -233,6 +233,114 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return results
def truncate_or_drop_long_seq(
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
"""
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
If handling is "drop":
- Samples that are too short or too long will be dropped
If handling is "truncate":
- Samples that are too short will still be dropped
- Samples that are too long will be truncated to sequence_len
Works for both single-example (list[int]) or batched (list[list[int]]).
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
"""
min_sequence_len = min_sequence_len or 2
result = None
if handling == "drop":
return drop_long_seq(sample, sequence_len, min_sequence_len)
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
if not input_ids:
result = False if handling == "drop" else sample
# Single example (input_ids is a list of int)
elif isinstance(input_ids[0], int):
length = len(input_ids)
# Handle samples that are too short - always drop them
if length < min_sequence_len:
result = False if handling == "drop" else sample
# If truncation is enabled and the sample is too long, truncate it
elif length > sequence_len and handling == "truncate":
sample["input_ids"] = input_ids[:sequence_len]
# Also truncate attention_mask if present
if "attention_mask" in sample:
sample["attention_mask"] = sample["attention_mask"][:sequence_len]
# Also truncate labels if present
if "labels" in sample:
sample["labels"] = sample["labels"][:sequence_len]
# Also truncate position_ids if present
if "position_ids" in sample:
sample["position_ids"] = sample["position_ids"][:sequence_len]
# Update length if present
if "length" in sample:
sample["length"] = sequence_len
result = sample
# For drop mode or if the sample doesn't exceed max length
else:
result = (
min_sequence_len <= length <= sequence_len
if handling == "drop"
else sample
)
# Batched (input_ids is a list of lists)
else:
if handling == "drop":
results = []
for seq in input_ids:
length = len(seq)
results.append(min_sequence_len <= length <= sequence_len)
result = results
else: # truncate
# Check each sequence in the batch
for i, seq in enumerate(input_ids):
length = len(seq)
# Skip sequences that are too short
if length < min_sequence_len:
continue
# Truncate sequences that are too long
if length > sequence_len:
input_ids[i] = seq[:sequence_len]
# Also truncate attention_mask if present
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][
:sequence_len
]
# Also truncate labels if present
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
# Also truncate position_ids if present
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][
:sequence_len
]
# Update length if present
if "length" in sample:
sample["length"][i] = sequence_len
result = sample
return result
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
@@ -368,15 +476,33 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
train_dataset,
sequence_len,
skip_position_ids=True,
drop_attention_mask=False,
handling="drop",
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
load_from_cache_file=False,
# Define the function to use for handling sequences based on the mode
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,
handling=handling, # Pass handling mode
)
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
train_dataset = train_dataset.map(
seq_handler_fn,
desc="Truncating Long Sequences",
load_from_cache_file=False,
)
else: # handling == "drop"
train_dataset = train_dataset.filter(
seq_handler_fn, # Use the same function, it returns boolean for drop mode
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
if not skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,

View File

@@ -47,7 +47,9 @@ class BaseCliTest:
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
with patch(mock_fn) as mock:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
@@ -65,8 +67,12 @@ class BaseCliTest:
if train:
expected.append("--shard=False")
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
if command == "train":
assert mock.call_args.args[0] == "accelerate"
assert mock.call_args.args[1] == expected
else:
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):

View File

@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
assert result.exit_code == 0
mock_subprocess.assert_called_once()
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "torchrun"
called_cmd = mock_subprocess.call_args.args[1]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present

View File

@@ -281,7 +281,9 @@ class TestHFRLTrainerBuilder:
# Other settings
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
assert training_arguments.gradient_checkpointing is False
# TODO(wing): restore once trl releases 0.22.0
# assert training_arguments.gradient_checkpointing is True
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)

View File

@@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from tests.e2e.utils import (
check_tensorboard,
require_torch_2_7_0,
require_torch_lt_2_6_0,
)
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -139,3 +143,71 @@ class TestMultiGPURay:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"save_first_step": False,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--use-ray",
"--ray-num-workers",
"2",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
"""
import unittest
from unittest.mock import MagicMock
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
from axolotl.utils.data.rl import drop_long_rl_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -64,5 +66,254 @@ class TestEncodePretraining(unittest.TestCase):
)
class TestDropLongRLSeq(unittest.TestCase):
"""
Tests for the drop_long_rl_seq function.
"""
def setUp(self):
# Mock tokenizer that returns length based on input string length
self.tokenizer = MagicMock()
def side_effect_func(
text, add_special_tokens=False
): # pylint: disable=unused-argument
return {"input_ids": list(range(len(text)))}
self.tokenizer.side_effect = side_effect_func
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
["x"] * len(tokens)
) # pylint: disable=unused-argument
self.sequence_len = 20
def test_dpo_drop_mode_valid(self):
"""Test DPO drop mode with a valid sample."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_dpo_drop_mode_invalid_chosen(self):
"""Test DPO drop mode with chosen too long."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 16,
"rejected": "r" * 6,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_drop_mode_invalid_rejected(self):
"""Test DPO drop mode with rejected too long."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 16,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_truncate_mode_no_truncation_needed(self):
"""Test DPO truncate mode when no truncation is needed."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(
result, original_sample
) # Should return the original sample unchanged
def test_dpo_truncate_mode_prompt_too_long(self):
"""Test DPO truncate mode when the prompt itself is too long."""
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
# Even though truncation isn't possible, the function should return the original sample
# for the map operation, assuming downstream filtering will catch it.
self.assertEqual(result, original_sample)
def test_dpo_truncate_mode_chosen_truncated(self):
"""Test DPO truncate mode when only 'chosen' needs truncation."""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 18,
"rejected": "r" * 10,
} # 5+18=23 > 20, 5+10=15 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["chosen"], "x" * max_resp_len
) # Check decoded truncated value
self.assertEqual(len(result["rejected"]), 10) # Unchanged
def test_dpo_truncate_mode_rejected_truncated(self):
"""Test DPO truncate mode when only 'rejected' needs truncation."""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 10,
"rejected": "r" * 18,
} # 5+10=15 <= 20, 5+18=23 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), 10) # Unchanged
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["rejected"], "x" * max_resp_len
) # Check decoded truncated value
def test_dpo_truncate_mode_both_truncated(self):
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
prompt_len = 8
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 15,
"rejected": "r" * 14,
} # 8+15=23 > 20, 8+14=22 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
self.assertEqual(result["rejected"], "x" * max_resp_len)
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
"""Test DPO truncate mode where individual parts fit but combined don't, but no truncation happens."""
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
# but the initial check failed because e.g. prompt + chosen > sequence_len
# The current logic *will* truncate if len(chosen) > max_resp_len.
# Let's test a case where one is slightly too long causing the initial fail,
# but the other fits *within* the max_response_len, so only one gets truncated.
prompt_len = 10
max_resp_len = self.sequence_len - prompt_len # 10
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 11,
"rejected": "r" * 9,
} # 10+11=21 > 20, 10+9=19 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
# Add similar tests for KTO if needed, checking prompt + completion length
def test_kto_drop_mode_valid(self):
"""Test KTO drop mode with a valid sample."""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_kto_drop_mode_invalid(self):
"""Test KTO drop mode with an invalid sample."""
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_kto_truncate_mode_no_truncation_needed(self):
"""Test KTO truncate mode when no truncation is needed."""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample)
def test_kto_truncate_mode_prompt_too_long(self):
"""Test KTO truncate mode when the prompt itself is too long."""
sample = {"prompt": "p" * 25, "completion": "c" * 7}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample) # Returns original sample
def test_kto_truncate_mode_completion_truncated(self):
"""Test KTO truncate mode when completion needs truncation."""
prompt_len = 8
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
self.assertEqual(result["completion"], "x" * max_comp_len)
def test_missing_keys_dpo(self):
"""Test ValueError raised if keys missing for DPO."""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt, chosen and rejected keys are required"
):
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
def test_missing_keys_kto(self):
"""Test ValueError raised if keys missing for KTO."""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt and completion keys are required"
):
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
def test_unknown_rl_type(self):
"""Test ValueError raised for unknown RL type."""
sample = {}
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
# GRPO test - current implementation always passes
def test_grpo_drop(self):
"""Test GRPO drop mode (currently always True)."""
sample = {}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_grpo_truncate(self):
"""Test GRPO truncate mode (currently returns original sample)."""
sample = {"a": 1}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, sample)
if __name__ == "__main__":
unittest.main()

153
tests/test_trainer_utils.py Normal file
View File

@@ -0,0 +1,153 @@
"""Module containing tests for trainer utility functions."""
import unittest
from functools import partial
from axolotl.utils.trainer import truncate_or_drop_long_seq
# Test cases for truncate_or_drop_long_seq
class TestTruncateOrDropLongSeq(unittest.TestCase):
"""
Test suite for truncate_or_drop_long_seq function.
"""
def setUp(self):
# Example sequence length settings
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""Test drop mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
# Too short
sample_short = {"input_ids": [1, 2]}
self.assertFalse(handler(sample_short))
# Too long
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
self.assertFalse(handler(sample_long))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
self.assertTrue(handler(sample_ok))
# Empty
sample_empty = {"input_ids": []}
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""Test truncate mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
# Too short (should still be dropped implicitly by filter/map logic upstream,
# but the function itself might return the sample or False based on impl.)
# Current impl returns the original sample for map if too short, assuming upstream filters.
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
result_short = handler(sample_short)
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
# Too long
original_long = list(range(self.sequence_len + 5))
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
result_long = handler(sample_long)
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
self.assertEqual(len(result_long["labels"]), self.sequence_len)
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
# Just right
sample_ok = {
"input_ids": list(range(self.min_sequence_len)),
"labels": list(range(self.min_sequence_len)),
}
result_ok = handler(sample_ok)
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
self.assertEqual(result_ok, sample_ok) # Should be unchanged
# Empty
sample_empty = {"input_ids": [], "labels": []}
result_empty = handler(sample_empty)
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""Test drop mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 1)), # Too long
list(range(self.sequence_len)), # OK (len = 10)
list(range(self.min_sequence_len)), # OK (len = 3)
[], # Empty
]
}
expected = [False, False, True, True, False]
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""Test truncate mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 5)), # Too long
list(range(self.sequence_len)), # OK
list(range(self.min_sequence_len)), # OK
[], # Empty
],
"labels": [ # Add labels to test truncation
[1, 2],
list(range(self.sequence_len + 5)),
list(range(self.sequence_len)),
list(range(self.min_sequence_len)),
[],
],
}
result = handler(sample)
# Expected results after truncation (too short and empty remain unchanged by this function)
expected_input_ids = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
expected_labels = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
self.assertEqual(result["input_ids"], expected_input_ids)
self.assertEqual(result["labels"], expected_labels)
if __name__ == "__main__":
unittest.main()