Compare commits
31 Commits
v0.12.0
...
775-option
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f2d196476 | ||
|
|
f1a8474400 | ||
|
|
dc5887c652 | ||
|
|
54b542d312 | ||
|
|
30a89b07b9 | ||
|
|
746c03b097 | ||
|
|
47b3fe8af3 | ||
|
|
f5a3e3529e | ||
|
|
3d45620008 | ||
|
|
ce20e838b5 | ||
|
|
d4d84d48af | ||
|
|
9b12c05660 | ||
|
|
686933194e | ||
|
|
d12b461d19 | ||
|
|
d6b81b3683 | ||
|
|
05f1b4b2e8 | ||
|
|
7cfc80ec77 | ||
|
|
0da6a95efa | ||
|
|
618b008e36 | ||
|
|
5d7a61576d | ||
|
|
5ecf22b54e | ||
|
|
9c5b8da22f | ||
|
|
fea6649518 | ||
|
|
124ad2b968 | ||
|
|
767c2340f1 | ||
|
|
f6623c34cc | ||
|
|
5dd8f0b2b8 | ||
|
|
be3c6bbd85 | ||
|
|
f07db4f853 | ||
|
|
17a5838d38 | ||
|
|
9f68918f13 |
18
.github/workflows/main.yml
vendored
18
.github/workflows/main.yml
vendored
@@ -98,6 +98,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
@@ -151,6 +157,18 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
7
.github/workflows/tests.yml
vendored
7
.github/workflows/tests.yml
vendored
@@ -105,7 +105,8 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
@@ -179,8 +180,8 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
|
||||
@@ -3,7 +3,7 @@ default_language_version:
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
@@ -23,7 +23,7 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.3.7
|
||||
rev: v3.3.8
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
|
||||
10
CITATION.cff
Normal file
10
CITATION.cff
Normal file
@@ -0,0 +1,10 @@
|
||||
cff-version: 1.2.0
|
||||
type: software
|
||||
title: "Axolotl: Post-Training for AI Models"
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "Axolotl maintainers and contributors"
|
||||
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
|
||||
url: "https://axolotl.ai/"
|
||||
license: Apache-2.0
|
||||
date-released: "2023-05-30"
|
||||
14
README.md
14
README.md
@@ -149,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
|
||||
|
||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||
|
||||
## 📝 Citing Axolotl
|
||||
|
||||
If you use Axolotl in your research or projects, please cite it as follows:
|
||||
|
||||
```bibtex
|
||||
@software{axolotl,
|
||||
title = {Axolotl: Post-Training for AI Models},
|
||||
author = {{Axolotl maintainers and contributors}},
|
||||
url = {https://github.com/axolotl-ai-cloud/axolotl},
|
||||
license = {Apache-2.0},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 📜 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.12.0"
|
||||
__version__ = "0.13.0.dev"
|
||||
|
||||
@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def plugin_set_cfg(cfg: DictDefault):
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.cfg = cfg
|
||||
# now that we have the finalized cfg, register the plugins individually
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def load_cfg(
|
||||
|
||||
@@ -123,9 +123,10 @@ def train(
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
# Process each configuration
|
||||
for cfg_file in generate_config_files(config, sweep):
|
||||
for cfg_file, is_group in generate_config_files(config, sweep):
|
||||
try:
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
||||
use_exec = is_group is not True
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
"""Generate list of configuration files to process."""
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||
"""
|
||||
Generate list of configuration files to process.
|
||||
|
||||
Args:
|
||||
config: Base configuration file
|
||||
sweep: Sweep configuration file
|
||||
|
||||
Yields:
|
||||
Tuple of configuration file name and whether this is a group of configurations
|
||||
"""
|
||||
|
||||
if not sweep:
|
||||
yield config
|
||||
yield config, False
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name
|
||||
yield temp_file.name, is_group
|
||||
|
||||
|
||||
def launch_training(
|
||||
@@ -97,6 +109,7 @@ def launch_training(
|
||||
cloud: str | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training with the given configuration."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -105,11 +118,14 @@ def launch_training(
|
||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||
elif launcher:
|
||||
if launcher == "accelerate":
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "torchrun":
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "python":
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
elif launcher is None:
|
||||
# handle ray train launch
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
|
||||
|
||||
def _launch_cloud_training(
|
||||
@@ -136,7 +152,10 @@ def _launch_cloud_training(
|
||||
|
||||
|
||||
def _launch_accelerate_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via accelerate launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -161,11 +180,20 @@ def _launch_accelerate_training(
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_torchrun_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via torchrun launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -178,7 +206,13 @@ def _launch_torchrun_training(
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||
|
||||
@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
|
||||
@@ -76,8 +76,8 @@ class BasePlugin:
|
||||
def __init__(self):
|
||||
"""Initializes the BasePlugin."""
|
||||
|
||||
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration.
|
||||
def register(self, cfg: dict): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration as an unparsed dict.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
|
||||
@@ -73,9 +73,6 @@ class PatchManager:
|
||||
self._apply_voxtral_patches()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||
patch_prepare_from_posids,
|
||||
)
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
patch_evaluation_loop,
|
||||
patch_maybe_log_save_evaluate,
|
||||
@@ -87,7 +84,6 @@ class PatchManager:
|
||||
and self.cfg.fsdp_version == 2
|
||||
)
|
||||
|
||||
patch_prepare_from_posids()
|
||||
patch_evaluation_loop(patch_fsdp2)
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
"""
|
||||
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
||||
|
||||
see https://github.com/huggingface/transformers/pull/39653/files
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _prepare_from_posids(query, key, value, position_ids):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
Return:
|
||||
query (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
indices_q (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input target sequence.
|
||||
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(
|
||||
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(
|
||||
position_ids.size(), device=position_ids.device, dtype=torch.int32
|
||||
),
|
||||
)
|
||||
)
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
return (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
indices_q,
|
||||
(cu_seq_lens, cu_seq_lens),
|
||||
(max_length, max_length),
|
||||
)
|
||||
|
||||
|
||||
def patch_prepare_from_posids():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
|
||||
_prepare_from_posids
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"_prepare_from_posids",
|
||||
_prepare_from_posids,
|
||||
)
|
||||
@@ -10,6 +10,7 @@ from torch.utils.data import RandomSampler
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||
@@ -259,6 +260,15 @@ def encode_packed_pretraining(
|
||||
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
|
||||
# workaround by using the position id logic for now in trainer
|
||||
drop_attention_mask=multipack_attn,
|
||||
# pass through handling mode from config via ds_wrapper function
|
||||
handling=(
|
||||
getattr(ds_wrapper, "cfg", {}).get(
|
||||
"sequence_len_overflow_handling",
|
||||
getattr(ds_wrapper, "cfg", {}).get(
|
||||
"excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
|
||||
@@ -122,6 +122,14 @@ def _map_dataset(
|
||||
return dataset
|
||||
|
||||
|
||||
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
|
||||
"""
|
||||
Backward-compatibility wrapper for legacy imports in tests.
|
||||
Delegates to the new predicate.
|
||||
"""
|
||||
return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
|
||||
|
||||
|
||||
def _drop_long_sequences(
|
||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||
) -> bool:
|
||||
@@ -155,11 +163,51 @@ def _drop_long_sequences(
|
||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
return (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling_mode == "truncate":
|
||||
# If both sequences fit, return sample unchanged
|
||||
if (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len:
|
||||
result = sample
|
||||
else:
|
||||
# Calculate maximum response length that can fit with the prompt
|
||||
max_response_len = sequence_len - len_prompt
|
||||
|
||||
if rl is RLType.KTO:
|
||||
if max_response_len <= 0:
|
||||
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
|
||||
LOG.warning(
|
||||
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping",
|
||||
len_prompt,
|
||||
sequence_len,
|
||||
)
|
||||
result = False
|
||||
|
||||
else:
|
||||
# Truncate the chosen and rejected responses if needed
|
||||
if len_chosen > max_response_len:
|
||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:max_response_len]
|
||||
sample["chosen"] = tokenizer.decode(
|
||||
chosen_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
if len_rejected > max_response_len:
|
||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:max_response_len]
|
||||
sample["rejected"] = tokenizer.decode(
|
||||
rejected_tokens, skip_special_tokens=True
|
||||
)
|
||||
result = sample
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
elif rl == RLType.KTO:
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||
|
||||
@@ -171,12 +219,86 @@ def _drop_long_sequences(
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
# Truncate first
|
||||
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling_mode == "truncate":
|
||||
# If sequence fits, return sample unchanged
|
||||
if (len_prompt + len_completion) <= sequence_len:
|
||||
result = sample
|
||||
else:
|
||||
# Calculate maximum completion length
|
||||
max_completion_len = sequence_len - len_prompt
|
||||
|
||||
if rl is RLType.GRPO:
|
||||
return True
|
||||
if max_completion_len <= 0:
|
||||
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
|
||||
LOG.warning(
|
||||
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping",
|
||||
len_prompt,
|
||||
sequence_len,
|
||||
)
|
||||
result = False
|
||||
else:
|
||||
# Truncate the completion if needed
|
||||
if len_completion > max_completion_len:
|
||||
completion_tokens = tokenizer(
|
||||
completion, add_special_tokens=False
|
||||
)["input_ids"][:max_completion_len]
|
||||
sample["completion"] = tokenizer.decode(
|
||||
completion_tokens, skip_special_tokens=True
|
||||
)
|
||||
result = sample
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
elif rl == RLType.GRPO:
|
||||
# For GRPO always keep
|
||||
result = True
|
||||
else:
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
return bool(result)
|
||||
|
||||
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
|
||||
"""
|
||||
Boolean predicate to check whether a preference-learning sample fits within sequence_len.
|
||||
Used with dataset.filter() after truncation to drop unsalvageable samples.
|
||||
"""
|
||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
return False
|
||||
prompt = sample["prompt"]
|
||||
chosen = sample["chosen"]
|
||||
rejected = sample["rejected"]
|
||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||
len_rejected = len(
|
||||
tokenizer(rejected, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
return (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
if rl == RLType.KTO:
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
return False
|
||||
prompt = sample["prompt"]
|
||||
completion = sample["completion"]
|
||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||
len_completion = len(
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
if rl == RLType.GRPO:
|
||||
# GRPO does not enforce this check here
|
||||
return True
|
||||
return False
|
||||
|
||||
# Legacy shim preserved for backward compatibility; no-op in new flow
|
||||
def load_split(dataset_cfgs, _cfg): # noqa: F811
|
||||
return None
|
||||
|
||||
|
||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
|
||||
@@ -15,10 +15,12 @@ from datasets import Dataset, IterableDataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
|
||||
|
||||
|
||||
class RetryStrategy(Enum):
|
||||
"""Enum for retry strategies."""
|
||||
@@ -168,10 +170,19 @@ def drop_long_seq_in_dataset(
|
||||
)
|
||||
return dataset
|
||||
|
||||
drop_long = functools.partial(
|
||||
drop_long_seq,
|
||||
# Get the handling method from config, default to "drop" for backward compatibility.
|
||||
# Support legacy alias "excess_token_handling" as well.
|
||||
handling = cfg.get(
|
||||
"sequence_len_overflow_handling",
|
||||
cfg.get("excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING),
|
||||
)
|
||||
|
||||
# Use the function with the specified handling mode
|
||||
seq_handler = functools.partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
handling=handling,
|
||||
)
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
@@ -190,17 +201,31 @@ def drop_long_seq_in_dataset(
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||
if handling == "truncate":
|
||||
drop_long_kwargs["desc"] = "Truncating Long Sequences"
|
||||
else: # handling == "drop"
|
||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||
|
||||
dataset = dataset.filter(
|
||||
drop_long,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
if handling == "truncate":
|
||||
# Use map for truncate mode
|
||||
dataset = dataset.map(
|
||||
seq_handler,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
LOG.info(f"Truncated long samples in dataset to {sequence_len} tokens")
|
||||
else: # handling == "drop"
|
||||
# Use filter for drop mode
|
||||
dataset = dataset.filter(
|
||||
seq_handler,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -414,6 +414,12 @@ class AxolotlInputConfig(
|
||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||
},
|
||||
)
|
||||
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
|
||||
default="drop",
|
||||
json_schema_extra={
|
||||
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
|
||||
},
|
||||
)
|
||||
eval_sequence_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -233,6 +233,114 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
return results
|
||||
|
||||
|
||||
def truncate_or_drop_long_seq(
|
||||
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
|
||||
):
|
||||
"""
|
||||
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
|
||||
If handling is "drop":
|
||||
- Samples that are too short or too long will be dropped
|
||||
If handling is "truncate":
|
||||
- Samples that are too short will still be dropped
|
||||
- Samples that are too long will be truncated to sequence_len
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
result = None
|
||||
|
||||
if handling == "drop":
|
||||
return drop_long_seq(sample, sequence_len, min_sequence_len)
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Edge case: if input_ids is empty
|
||||
if not input_ids:
|
||||
result = False if handling == "drop" else sample
|
||||
# Single example (input_ids is a list of int)
|
||||
elif isinstance(input_ids[0], int):
|
||||
length = len(input_ids)
|
||||
|
||||
# Handle samples that are too short - always drop them
|
||||
if length < min_sequence_len:
|
||||
result = False if handling == "drop" else sample
|
||||
# If truncation is enabled and the sample is too long, truncate it
|
||||
elif length > sequence_len and handling == "truncate":
|
||||
sample["input_ids"] = input_ids[:sequence_len]
|
||||
|
||||
# Also truncate attention_mask if present
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"] = sample["attention_mask"][:sequence_len]
|
||||
|
||||
# Also truncate labels if present
|
||||
if "labels" in sample:
|
||||
sample["labels"] = sample["labels"][:sequence_len]
|
||||
|
||||
# Also truncate position_ids if present
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"] = sample["position_ids"][:sequence_len]
|
||||
|
||||
# Update length if present
|
||||
if "length" in sample:
|
||||
sample["length"] = sequence_len
|
||||
|
||||
result = sample
|
||||
# For drop mode or if the sample doesn't exceed max length
|
||||
else:
|
||||
result = (
|
||||
min_sequence_len <= length <= sequence_len
|
||||
if handling == "drop"
|
||||
else sample
|
||||
)
|
||||
# Batched (input_ids is a list of lists)
|
||||
else:
|
||||
if handling == "drop":
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
length = len(seq)
|
||||
results.append(min_sequence_len <= length <= sequence_len)
|
||||
result = results
|
||||
else: # truncate
|
||||
# Check each sequence in the batch
|
||||
for i, seq in enumerate(input_ids):
|
||||
length = len(seq)
|
||||
|
||||
# Skip sequences that are too short
|
||||
if length < min_sequence_len:
|
||||
continue
|
||||
|
||||
# Truncate sequences that are too long
|
||||
if length > sequence_len:
|
||||
input_ids[i] = seq[:sequence_len]
|
||||
|
||||
# Also truncate attention_mask if present
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"][i] = sample["attention_mask"][i][
|
||||
:sequence_len
|
||||
]
|
||||
|
||||
# Also truncate labels if present
|
||||
if "labels" in sample:
|
||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||
|
||||
# Also truncate position_ids if present
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"][i] = sample["position_ids"][i][
|
||||
:sequence_len
|
||||
]
|
||||
|
||||
# Update length if present
|
||||
if "length" in sample:
|
||||
sample["length"][i] = sequence_len
|
||||
|
||||
result = sample
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
|
||||
if drop_attn_mask:
|
||||
@@ -368,15 +476,33 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
|
||||
|
||||
def process_pretraining_datasets_for_packing(
|
||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
||||
train_dataset,
|
||||
sequence_len,
|
||||
skip_position_ids=True,
|
||||
drop_attention_mask=False,
|
||||
handling="drop",
|
||||
):
|
||||
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
|
||||
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
# Define the function to use for handling sequences based on the mode
|
||||
seq_handler_fn = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
handling=handling, # Pass handling mode
|
||||
)
|
||||
|
||||
# Use map for truncate mode and filter for drop mode
|
||||
if handling == "truncate":
|
||||
train_dataset = train_dataset.map(
|
||||
seq_handler_fn,
|
||||
desc="Truncating Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
else: # handling == "drop"
|
||||
train_dataset = train_dataset.filter(
|
||||
seq_handler_fn, # Use the same function, it returns boolean for drop mode
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
if not skip_position_ids:
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids,
|
||||
|
||||
@@ -47,7 +47,9 @@ class BaseCliTest:
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock:
|
||||
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
||||
|
||||
with patch(mock_fn) as mock:
|
||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||
|
||||
assert mock.called
|
||||
@@ -65,8 +67,12 @@ class BaseCliTest:
|
||||
if train:
|
||||
expected.append("--shard=False")
|
||||
|
||||
assert mock.call_args.args[0] == expected
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
if command == "train":
|
||||
assert mock.call_args.args[0] == "accelerate"
|
||||
assert mock.call_args.args[1] == expected
|
||||
else:
|
||||
assert mock.call_args.args[0] == expected
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
assert result.exit_code == 0
|
||||
|
||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||
|
||||
@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to torchrun
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "torchrun"
|
||||
assert "--nproc_per_node=2" in called_cmd
|
||||
assert "--nnodes=1" in called_cmd
|
||||
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to accelerate
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify no launcher args contamination
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
# Should not contain any extra launcher args
|
||||
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert mock_subprocess.call_args.args[0] == "torchrun"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
# Verify launcher args
|
||||
assert "--nproc_per_node=8" in called_cmd
|
||||
# Verify axolotl args are also present
|
||||
|
||||
@@ -281,7 +281,9 @@ class TestHFRLTrainerBuilder:
|
||||
# Other settings
|
||||
assert training_arguments.dataloader_num_workers == 1
|
||||
assert training_arguments.dataloader_pin_memory is True
|
||||
assert training_arguments.gradient_checkpointing is False
|
||||
|
||||
# TODO(wing): restore once trl releases 0.22.0
|
||||
# assert training_arguments.gradient_checkpointing is True
|
||||
|
||||
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
||||
|
||||
@@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||
from tests.e2e.utils import (
|
||||
check_tensorboard,
|
||||
require_torch_2_7_0,
|
||||
require_torch_lt_2_6_0,
|
||||
)
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
@@ -139,3 +143,71 @@ class TestMultiGPURay:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
)
|
||||
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--use-ray",
|
||||
"--ray-num-workers",
|
||||
"2",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from axolotl.utils.data import encode_pretraining, md5
|
||||
from axolotl.utils.data.rl import drop_long_rl_seq
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@@ -64,5 +66,254 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestDropLongRLSeq(unittest.TestCase):
|
||||
"""
|
||||
Tests for the drop_long_rl_seq function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# Mock tokenizer that returns length based on input string length
|
||||
self.tokenizer = MagicMock()
|
||||
|
||||
def side_effect_func(
|
||||
text, add_special_tokens=False
|
||||
): # pylint: disable=unused-argument
|
||||
return {"input_ids": list(range(len(text)))}
|
||||
|
||||
self.tokenizer.side_effect = side_effect_func
|
||||
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
|
||||
["x"] * len(tokens)
|
||||
) # pylint: disable=unused-argument
|
||||
|
||||
self.sequence_len = 20
|
||||
|
||||
def test_dpo_drop_mode_valid(self):
|
||||
"""Test DPO drop mode with a valid sample."""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
"rejected": "r" * 6,
|
||||
} # 5+7=12 <= 20, 5+6=11 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_dpo_drop_mode_invalid_chosen(self):
|
||||
"""Test DPO drop mode with chosen too long."""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 16,
|
||||
"rejected": "r" * 6,
|
||||
} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dpo_drop_mode_invalid_rejected(self):
|
||||
"""Test DPO drop mode with rejected too long."""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
"rejected": "r" * 16,
|
||||
} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dpo_truncate_mode_no_truncation_needed(self):
|
||||
"""Test DPO truncate mode when no truncation is needed."""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
"rejected": "r" * 6,
|
||||
} # 5+7=12 <= 20, 5+6=11 <= 20
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(
|
||||
result, original_sample
|
||||
) # Should return the original sample unchanged
|
||||
|
||||
def test_dpo_truncate_mode_prompt_too_long(self):
|
||||
"""Test DPO truncate mode when the prompt itself is too long."""
|
||||
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
# Even though truncation isn't possible, the function should return the original sample
|
||||
# for the map operation, assuming downstream filtering will catch it.
|
||||
self.assertEqual(result, original_sample)
|
||||
|
||||
def test_dpo_truncate_mode_chosen_truncated(self):
|
||||
"""Test DPO truncate mode when only 'chosen' needs truncation."""
|
||||
prompt_len = 5
|
||||
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 18,
|
||||
"rejected": "r" * 10,
|
||||
} # 5+18=23 > 20, 5+10=15 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
|
||||
self.assertEqual(
|
||||
result["chosen"], "x" * max_resp_len
|
||||
) # Check decoded truncated value
|
||||
self.assertEqual(len(result["rejected"]), 10) # Unchanged
|
||||
|
||||
def test_dpo_truncate_mode_rejected_truncated(self):
|
||||
"""Test DPO truncate mode when only 'rejected' needs truncation."""
|
||||
prompt_len = 5
|
||||
max_resp_len = self.sequence_len - prompt_len # 15
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 10,
|
||||
"rejected": "r" * 18,
|
||||
} # 5+10=15 <= 20, 5+18=23 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), 10) # Unchanged
|
||||
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
|
||||
self.assertEqual(
|
||||
result["rejected"], "x" * max_resp_len
|
||||
) # Check decoded truncated value
|
||||
|
||||
def test_dpo_truncate_mode_both_truncated(self):
|
||||
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
|
||||
prompt_len = 8
|
||||
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 15,
|
||||
"rejected": "r" * 14,
|
||||
} # 8+15=23 > 20, 8+14=22 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
|
||||
self.assertEqual(result["chosen"], "x" * max_resp_len)
|
||||
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
|
||||
self.assertEqual(result["rejected"], "x" * max_resp_len)
|
||||
|
||||
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
|
||||
"""Test DPO truncate mode where individual parts fit but combined don't, but no truncation happens."""
|
||||
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
|
||||
# but the initial check failed because e.g. prompt + chosen > sequence_len
|
||||
# The current logic *will* truncate if len(chosen) > max_resp_len.
|
||||
# Let's test a case where one is slightly too long causing the initial fail,
|
||||
# but the other fits *within* the max_response_len, so only one gets truncated.
|
||||
prompt_len = 10
|
||||
max_resp_len = self.sequence_len - prompt_len # 10
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 11,
|
||||
"rejected": "r" * 9,
|
||||
} # 10+11=21 > 20, 10+9=19 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
|
||||
self.assertEqual(result["chosen"], "x" * max_resp_len)
|
||||
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
|
||||
|
||||
# Add similar tests for KTO if needed, checking prompt + completion length
|
||||
|
||||
def test_kto_drop_mode_valid(self):
|
||||
"""Test KTO drop mode with a valid sample."""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_kto_drop_mode_invalid(self):
|
||||
"""Test KTO drop mode with an invalid sample."""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_kto_truncate_mode_no_truncation_needed(self):
|
||||
"""Test KTO truncate mode when no truncation is needed."""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(result, original_sample)
|
||||
|
||||
def test_kto_truncate_mode_prompt_too_long(self):
|
||||
"""Test KTO truncate mode when the prompt itself is too long."""
|
||||
sample = {"prompt": "p" * 25, "completion": "c" * 7}
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(result, original_sample) # Returns original sample
|
||||
|
||||
def test_kto_truncate_mode_completion_truncated(self):
|
||||
"""Test KTO truncate mode when completion needs truncation."""
|
||||
prompt_len = 8
|
||||
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
|
||||
self.assertEqual(result["completion"], "x" * max_comp_len)
|
||||
|
||||
def test_missing_keys_dpo(self):
|
||||
"""Test ValueError raised if keys missing for DPO."""
|
||||
sample = {"prompt": "p"}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Prompt, chosen and rejected keys are required"
|
||||
):
|
||||
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
|
||||
|
||||
def test_missing_keys_kto(self):
|
||||
"""Test ValueError raised if keys missing for KTO."""
|
||||
sample = {"prompt": "p"}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Prompt and completion keys are required"
|
||||
):
|
||||
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
|
||||
|
||||
def test_unknown_rl_type(self):
|
||||
"""Test ValueError raised for unknown RL type."""
|
||||
sample = {}
|
||||
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
|
||||
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
|
||||
|
||||
# GRPO test - current implementation always passes
|
||||
def test_grpo_drop(self):
|
||||
"""Test GRPO drop mode (currently always True)."""
|
||||
sample = {}
|
||||
result = drop_long_rl_seq(
|
||||
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_grpo_truncate(self):
|
||||
"""Test GRPO truncate mode (currently returns original sample)."""
|
||||
sample = {"a": 1}
|
||||
result = drop_long_rl_seq(
|
||||
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(result, sample)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
153
tests/test_trainer_utils.py
Normal file
153
tests/test_trainer_utils.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Module containing tests for trainer utility functions."""
|
||||
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
|
||||
# Test cases for truncate_or_drop_long_seq
|
||||
class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
"""
|
||||
Test suite for truncate_or_drop_long_seq function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# Example sequence length settings
|
||||
self.sequence_len = 10
|
||||
self.min_sequence_len = 3
|
||||
|
||||
def test_drop_mode_single(self):
|
||||
"""Test drop mode with single examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="drop",
|
||||
)
|
||||
|
||||
# Too short
|
||||
sample_short = {"input_ids": [1, 2]}
|
||||
self.assertFalse(handler(sample_short))
|
||||
|
||||
# Too long
|
||||
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
|
||||
self.assertFalse(handler(sample_long))
|
||||
|
||||
# Just right
|
||||
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
|
||||
self.assertTrue(handler(sample_ok))
|
||||
|
||||
# Empty
|
||||
sample_empty = {"input_ids": []}
|
||||
self.assertFalse(handler(sample_empty))
|
||||
|
||||
def test_truncate_mode_single(self):
|
||||
"""Test truncate mode with single examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="truncate",
|
||||
)
|
||||
|
||||
# Too short (should still be dropped implicitly by filter/map logic upstream,
|
||||
# but the function itself might return the sample or False based on impl.)
|
||||
# Current impl returns the original sample for map if too short, assuming upstream filters.
|
||||
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
|
||||
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
|
||||
result_short = handler(sample_short)
|
||||
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
|
||||
|
||||
# Too long
|
||||
original_long = list(range(self.sequence_len + 5))
|
||||
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
|
||||
result_long = handler(sample_long)
|
||||
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
|
||||
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
|
||||
self.assertEqual(len(result_long["labels"]), self.sequence_len)
|
||||
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
|
||||
|
||||
# Just right
|
||||
sample_ok = {
|
||||
"input_ids": list(range(self.min_sequence_len)),
|
||||
"labels": list(range(self.min_sequence_len)),
|
||||
}
|
||||
result_ok = handler(sample_ok)
|
||||
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
|
||||
self.assertEqual(result_ok, sample_ok) # Should be unchanged
|
||||
|
||||
# Empty
|
||||
sample_empty = {"input_ids": [], "labels": []}
|
||||
result_empty = handler(sample_empty)
|
||||
self.assertEqual(result_empty, sample_empty) # Unchanged
|
||||
|
||||
def test_drop_mode_batched(self):
|
||||
"""Test drop mode with batched examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="drop",
|
||||
)
|
||||
sample = {
|
||||
"input_ids": [
|
||||
[1, 2], # Too short
|
||||
list(range(self.sequence_len + 1)), # Too long
|
||||
list(range(self.sequence_len)), # OK (len = 10)
|
||||
list(range(self.min_sequence_len)), # OK (len = 3)
|
||||
[], # Empty
|
||||
]
|
||||
}
|
||||
expected = [False, False, True, True, False]
|
||||
self.assertEqual(handler(sample), expected)
|
||||
|
||||
def test_truncate_mode_batched(self):
|
||||
"""Test truncate mode with batched examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="truncate",
|
||||
)
|
||||
sample = {
|
||||
"input_ids": [
|
||||
[1, 2], # Too short
|
||||
list(range(self.sequence_len + 5)), # Too long
|
||||
list(range(self.sequence_len)), # OK
|
||||
list(range(self.min_sequence_len)), # OK
|
||||
[], # Empty
|
||||
],
|
||||
"labels": [ # Add labels to test truncation
|
||||
[1, 2],
|
||||
list(range(self.sequence_len + 5)),
|
||||
list(range(self.sequence_len)),
|
||||
list(range(self.min_sequence_len)),
|
||||
[],
|
||||
],
|
||||
}
|
||||
|
||||
result = handler(sample)
|
||||
|
||||
# Expected results after truncation (too short and empty remain unchanged by this function)
|
||||
expected_input_ids = [
|
||||
[1, 2], # Unchanged (too short)
|
||||
list(range(self.sequence_len)), # Truncated
|
||||
list(range(self.sequence_len)), # Unchanged (OK)
|
||||
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||
[], # Unchanged (Empty)
|
||||
]
|
||||
expected_labels = [
|
||||
[1, 2], # Unchanged (too short)
|
||||
list(range(self.sequence_len)), # Truncated
|
||||
list(range(self.sequence_len)), # Unchanged (OK)
|
||||
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||
[], # Unchanged (Empty)
|
||||
]
|
||||
|
||||
self.assertEqual(result["input_ids"], expected_input_ids)
|
||||
self.assertEqual(result["labels"], expected_labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user