Compare commits

...

11 Commits

Author SHA1 Message Date
Dan Saunders
1defb8a955 Merge branch 'main' into destroy-pg 2025-03-31 14:36:43 -04:00
Dan Saunders
70b466aa67 ray bugfix 2025-03-31 18:35:41 +00:00
Dan Saunders
ef6eb77cc8 destroy process group on Ctrl+C / training or eval run (#2457)
* fix nccl pg destroy warning

* update
2025-03-31 12:36:47 -04:00
Dan Saunders
32ce167404 update 2025-03-31 14:46:15 +00:00
Dan Saunders
1c4cc639f5 fix nccl pg destroy warning 2025-03-31 14:32:50 +00:00
Dan Saunders
5410195e0b Sequence parallelism quick follow-ups; remove ModelCallback (#2450)
* guard return if ring attn already registered

* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)

* configurable heads_k_stride from ring-flash-attn hf adapter
2025-03-31 09:13:42 -04:00
NanoCode012
cf0c79d52e fix: minor patches for multimodal (#2441)
* fix: update chat_template

* fix: handle gemma3 showing a lot of no content for turn 0

* fix: remove unknown config from examples

* fix: test

* fix: temporary disable gemma2 test

* fix: stop overwriting config.text_config unnecessarily

* fix: handling of set cache to the text_config section

* feat: add liger gemma support and bump liger to 0.5.5

* fix: add double use_cache setting

* fix: add support for final_logit_softcap in CCE for gemma2/3

* fix: set use_cache before model load

* feat: add missing layernorm override

* fix: handle gemma3 rmsnorm

* fix: use wrapper to pass dim as hidden_size

* fix: change dim to positional

* fix: patch with wrong mlp

* chore: refactor use_cache handling

* fix import issues

* fix tests.e2e.utils import

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-31 13:40:12 +07:00
Wing Lian
4ba80a0e5a fix streaming packing test (#2454)
* fix streaming packing test

* constrain amount of text generated
2025-03-29 08:30:06 -04:00
Wing Lian
c49682132b use offline for precached stream dataset (#2453) 2025-03-28 23:39:09 -04:00
Wing Lian
e46239f8d3 bump liger to 0.5.5 (#2448) 2025-03-28 19:21:03 -04:00
Wing Lian
05f03b541a hf offline decorator for tests to workaround rate limits (#2452) [skip ci]
* hf offline decorator for tests to workaround rate limits

* fail quicker so we can see logs

* try new cache name

* limit files downloaded

* phi mini predownload

* offline decorator for phi tokenizer

* handle meta llama 8b offline too

* make sure to return fixtures if they are wrapped too

* more fixes

* more things offline

* more offline things

* fix the env var

* fix the model name

* handle gemma also

* force reload of modules to recheck offline status

* prefetch mistral too

* use reset_sessions so hub picks up offline mode

* more fixes

* rename so it doesn't seem like a context manager

* fix backoff

* switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline

* include additional dataset

* more fixes

* more fixes

* replace tiny shakespeare dataset

* skip some tests for now

* use more robust check using snapshot download to determine if a dataset name is on the hub

* typo for skip reason

* use local_files_only

* more fixtures

* remove local only

* use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached

* make sure fixtures aren't offline

improve the offline reset
try bumping version of datasets
reorder reloading and setting
prime a new cache
run the tests now with fresh cache
try with a static cache

* now run all the ci again with hopefully a correct cache

* skip wonky tests for now

* skip wonky tests for now

* handle offline mode for model card creation
2025-03-28 19:20:46 -04:00
53 changed files with 815 additions and 291 deletions

View File

@@ -136,4 +136,4 @@ jobs:
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.tests modal run cicd.e2e_tests

View File

@@ -63,7 +63,7 @@ jobs:
path: | path: |
/home/runner/.cache/huggingface/hub/datasets--* /home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--* /home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }} key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
@@ -137,7 +137,7 @@ jobs:
path: | path: |
/home/runner/.cache/huggingface/hub/datasets--* /home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--* /home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }} key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
@@ -171,6 +171,9 @@ jobs:
run: | run: |
axolotl --help axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -229,7 +232,7 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.tests modal run cicd.e2e_tests
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
@@ -276,4 +279,4 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.tests modal run cicd.e2e_tests

View File

@@ -1,3 +1,4 @@
[settings] [settings]
profile=black profile=black
known_third_party=wandb,comet_ml known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -243,6 +243,7 @@ website:
- docs/unsloth.qmd - docs/unsloth.qmd
- docs/torchao.qmd - docs/torchao.qmd
- docs/custom_integrations.qmd - docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- section: "Troubleshooting" - section: "Troubleshooting"
contents: contents:

View File

@@ -658,6 +658,9 @@ ddp_broadcast_buffers:
# subsequences, or set to 4 to split into four equal-sized subsequences. # subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details. # See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree: sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path: torchdistx_path:

View File

@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended) - DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel) - FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA - FSDP + QLoRA
## DeepSpeed {#sec-deepspeed} ## DeepSpeed {#sec-deepspeed}
@@ -66,6 +67,28 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
``` ```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora} ### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd). For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

View File

@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml ```yaml
# Set to a divisor (> 1) of the number of GPUs available # Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
``` ```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
## Example ## Example
```yaml ```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192 sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
... ...
``` ```

View File

@@ -10,7 +10,7 @@ load_in_4bit: true
strict: false strict: false
# huggingface repo # huggingface repo
chat_template: gemma3_text chat_template: gemma3
datasets: datasets:
- path: cgato/SlimOrcaDedupCleaned - path: cgato/SlimOrcaDedupCleaned
type: chat_template type: chat_template

View File

@@ -19,7 +19,6 @@ val_set_size: 0.0
output_dir: ./outputs/lora-out output_dir: ./outputs/lora-out
dataset_exact_deduplication: true dataset_exact_deduplication: true
test_value: true
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.3 liger-kernel==0.5.5
# END section # END section
packaging==23.2 packaging==23.2
@@ -15,7 +15,7 @@ peft==0.15.0
transformers==4.50.0 transformers==4.50.0
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.5.2 accelerate==1.5.2
datasets==3.4.1 datasets==3.5.0
deepspeed==0.16.4 deepspeed==0.16.4
trl==0.15.1 trl==0.15.1

View File

@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback, LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SaveModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory, causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
@@ -249,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.gc_steps: if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks return callbacks
@@ -937,7 +935,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback())
return callbacks return callbacks

View File

@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
del model del model
del tokenizer del tokenizer
cleanup_distributed()
return all_metrics return all_metrics

View File

@@ -25,8 +25,8 @@ import torch
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy") LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")

View File

@@ -15,7 +15,6 @@ import transformers
from cut_cross_entropy.transformers.utils import ( from cut_cross_entropy.transformers.utils import (
PatchOptions, PatchOptions,
TransformersModelT, TransformersModelT,
apply_lce,
) )
from torch import nn from torch import nn
from transformers.cache_utils import Cache, HybridCache from transformers.cache_utils import Cache, HybridCache
@@ -33,6 +32,8 @@ from transformers.utils import (
) )
from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.deprecation import deprecate_kwarg
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
_PATCH_OPTS: PatchOptions | None = None _PATCH_OPTS: PatchOptions | None = None
@@ -134,25 +135,17 @@ def cce_forward(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None assert labels is not None
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
)
loss = apply_lce( loss = apply_lce(
hidden_states[:, slice_indices, :], hidden_states[:, slice_indices, :],
self.lm_head.weight, self.lm_head.weight,
labels, labels,
_PATCH_OPTS, _PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**loss_kwargs, **loss_kwargs,
) )
elif _PATCH_OPTS is not None and defer_logits_calculation: elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward # defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :] logits = hidden_states[:, slice_indices, :]
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
)
else: else:
logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = self.lm_head(hidden_states[:, slice_indices, :])
if self.config.final_logit_softcapping is not None: if self.config.final_logit_softcapping is not None:
@@ -353,6 +346,7 @@ def cce_forward_multimodal(
self.language_model.lm_head.weight, self.language_model.lm_head.weight,
labels, labels,
_PATCH_OPTS, _PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**lm_kwargs, **lm_kwargs,
) )
else: else:

View File

@@ -0,0 +1,40 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Monkeypatch for apply_lce to add softcap."""
import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions
def apply_lce(
e: torch.Tensor,
c: torch.Tensor,
labels: torch.Tensor,
opts: PatchOptions,
bias: torch.Tensor | None = None,
softcap: float | None = None,
**loss_kwargs,
) -> torch.Tensor:
"""Monkey patch for apply_lce to support softcap kwarg."""
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
cce_kwargs = opts.to_kwargs()
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
cce_kwargs["reduction"] = "sum"
else:
num_items_in_batch = None
loss = linear_cross_entropy(
e,
c,
labels.to(e.device),
bias=bias,
shift=True,
softcap=softcap,
**cce_kwargs,
)
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return loss

View File

@@ -20,6 +20,26 @@ liger_layer_norm: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
``` ```
## Supported Models
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
- mistral
- mixtral
- mllama
- mllama_text_model
- olmo2
- paligemma
- phi3
- qwen2
- qwen2_5_vl
- qwen2_vl
## Citation ## Citation
```bib ```bib

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect import inspect
import logging import logging
import sys import sys
from functools import partial
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn) liger_fn_sig = inspect.signature(apply_liger_fn)
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
modeling_jamba.JambaRMSNorm = LigerRMSNorm modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation: if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn from transformers.loss.loss_utils import nn
@@ -104,15 +114,51 @@ class LigerPlugin(BasePlugin):
# The DeepseekV2 version of RoPE is different than upstream LLaMA. # The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.") logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation: if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method. # nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]: elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}") raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int): def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
""" """
Create ring attention group and substitute flash attn with ring flash attn. Create ring attention group and substitute flash attn with ring flash attn.
Args: Args:
sequence_parallel_degree: Sequence parallelism factor. sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
""" """
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info( LOG.info(
"Enabling ring attention sequence parallelism: " "Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs" f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
if rank == 0: if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}") LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
from ring_flash_attn import substitute_hf_flash_attn from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree) substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)

View File

@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if turn_idx >= len(turns): if turn_idx >= len(turns):
raise ValueError(f"Turn index {turn_idx} out of range") raise ValueError(f"Turn index {turn_idx} out of range")
# mistral does not output message if it contains only system message # mistral/gemma3 does not output message if it contains only system message
if ( if (
turn_idx == 0 turn_idx == 0
and turns[0].get("role") == "system" and turns[0].get("role") == "system"
and "mistral" in self.tokenizer.name_or_path.lower() and (
"mistral" in self.tokenizer.name_or_path.lower()
# gemma3 uses gemma tokenizer
or "gemma" in self.tokenizer.name_or_path.lower()
)
): ):
return -1, -1 return -1, -1

View File

@@ -14,6 +14,7 @@ import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -26,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -156,6 +158,8 @@ def setup_signal_handler(
_model.save_pretrained( _model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
) )
cleanup_distributed()
sys.exit(0) sys.exit(0)
_model_weakref = weakref.ref(model) _model_weakref = weakref.ref(model)
@@ -302,7 +306,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
model_card_kwarg["dataset_tags"] = dataset_tags model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg) trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError): except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
pass pass
elif cfg.hub_model_id: elif cfg.hub_model_id:
# Defensively push to the hub to ensure the model card is updated # Defensively push to the hub to ensure the model card is updated
@@ -477,7 +481,7 @@ def train(
Returns: Returns:
Tuple of (model, tokenizer) after training Tuple of (model, tokenizer) after training
""" """
# Setup model, tokenizer, (causal or RLHF) trainer etc. # Setup model, tokenizer, (causal or RLHF) trainer, etc.
( (
trainer, trainer,
model, model,
@@ -486,34 +490,26 @@ def train(
processor, processor,
) = setup_model_and_trainer(cfg, dataset_meta) ) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured # Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix( handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization cfg, model, tokenizer, train_dataset, safe_serialization
) )
# Save initial configs # Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor) save_initial_configs(cfg, tokenizer, model, peft_config, processor)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization) setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg) setup_model_card(cfg)
# Execute the training # Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint) execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model # Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization) save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer) create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
return model, tokenizer, trainer return model, tokenizer, trainer

View File

@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control return control
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control
class GCCallback(TrainerCallback): class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache""" """Callback to garbage collect torch cache"""

View File

@@ -6,8 +6,12 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import HFValidationError from huggingface_hub.errors import (
HFValidationError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -70,20 +74,25 @@ def load_dataset_w_config(
# pylint: disable=invalid-name # pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try: try:
# this is just a basic check to see if the path is a # this is just a basic check to see if the path is a
# valid HF dataset that's loadable # valid HF dataset that's loadable
load_dataset( snapshot_download(
config_dataset.path, repo_id=config_dataset.path,
name=config_dataset.name, repo_type="dataset",
streaming=True,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision, revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code, ignore_patterns=["*"],
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): except (
RepositoryNotFoundError,
RevisionNotFoundError,
FileNotFoundError,
ConnectionError,
HFValidationError,
ValueError,
):
pass pass
ds_from_cloud = False ds_from_cloud = False

View File

@@ -71,8 +71,8 @@ def barrier():
def is_main_process(): def is_main_process():
""" """
Check if the current process is the main process. Check if the current process is the main process. If not in distributed mode,
If not in distributed mode, always return True. always return `True`.
""" """
if not is_distributed(): if not is_distributed():
return True return True
@@ -87,6 +87,18 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1")) return int(os.getenv("WORLD_SIZE", "1"))
def cleanup_distributed():
"""
Destroy process group if torch distributed is initialized. Called in training early
termination or when training successfully completes.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@contextmanager @contextmanager
def zero_only(): def zero_only():
""" """

View File

@@ -8,7 +8,7 @@ import math
import os import os
import types import types
from functools import cached_property from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple
import addict import addict
import bitsandbytes as bnb import bitsandbytes as bnb
@@ -25,7 +25,7 @@ from peft import (
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from torch import nn from torch import nn
from transformers import ( # noqa: F401 from transformers import (
AddedToken, AddedToken,
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -39,6 +39,7 @@ from transformers import ( # noqa: F401
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration, MllamaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
ProcessorMixin, ProcessorMixin,
@@ -107,14 +108,21 @@ def get_module_class_from_name(module, name):
return None return None
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
# Set use_cache to False
if hasattr(model_config, "use_cache"):
model_config.use_cache = False
if cfg.is_multimodal: if cfg.is_multimodal:
if hasattr(model_config, "text_config"): # For multimodal configs, use_cache is set in the text_config
model_config = model_config.text_config if hasattr(model_config, "get_text_config"):
model_config.use_cache = False text_config = model_config.get_text_config()
elif hasattr(model_config, "get_text_config"): if hasattr(text_config, "use_cache"):
model_config = model_config.get_text_config() text_config.use_cache = False
model_config.use_cache = False else:
raise ValueError(
"No text config found for multimodal model. Please raise an Issue with model details."
)
# check if image_size is not set and load image size from model config if available # check if image_size is not set and load image size from model config if available
if ( if (
@@ -523,14 +531,6 @@ class ModelLoader:
# init model config # init model config
self.model_config = load_model_config(cfg) self.model_config = load_model_config(cfg)
if cfg.is_multimodal:
if hasattr(self.model_config, "text_config"):
self.text_model_config = self.model_config.text_config
else:
# for qwen2_vl
self.text_model_config = self.model_config.get_text_config()
else:
self.text_model_config = self.model_config
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
@@ -609,7 +609,10 @@ class ModelLoader:
# Initialize ring attn for sequence parallelism. This must be done after # Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash # model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs. # attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree) register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
)
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):
@@ -947,8 +950,6 @@ class ModelLoader:
quantization_config = ( quantization_config = (
quantization_config or self.model_kwargs["quantization_config"] quantization_config or self.model_kwargs["quantization_config"]
) )
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = load_sharded_model_quant( self.model = load_sharded_model_quant(
self.base_model, self.base_model,
self.model_config, self.model_config,
@@ -969,9 +970,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading() _ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
# Load model with random initialization if specified # Load model with random initialization if specified
if self.cfg.random_init_weights: if self.cfg.random_init_weights:
# AutoModel classes support the from_config method # AutoModel classes support the from_config method
@@ -1026,8 +1024,6 @@ class ModelLoader:
and self.model_type != "AutoModelForCausalLM" and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
): ):
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq: if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
@@ -1043,25 +1039,7 @@ class ModelLoader:
**self.model_kwargs, **self.model_kwargs,
) )
else: else:
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(self.text_model_config, "max_seq_len")
and self.text_model_config.max_seq_len
and self.cfg.sequence_len > self.text_model_config.max_seq_len
):
self.text_model_config.max_seq_len = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
elif (
hasattr(self.text_model_config, "max_sequence_length")
and self.text_model_config.max_sequence_length
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
):
self.text_model_config.max_sequence_length = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
if self.cfg.gptq: if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
@@ -1080,8 +1058,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading() _ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
@@ -1346,8 +1322,6 @@ class ModelLoader:
requires_grad.append(f"{name}: {param.requires_grad}") requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0: if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates") LOG.warning("there are no parameters that require gradient updates")
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if self.cfg.flash_optimum: if self.cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer

View File

@@ -248,6 +248,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0) val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
special_tokens: SpecialTokensConfig | None = None special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None tokens: list[str] | None = None
@@ -1108,7 +1109,7 @@ class AxolotlInputConfig(
@field_validator("sequence_parallel_degree", mode="before") @field_validator("sequence_parallel_degree", mode="before")
@classmethod @classmethod
def check_sequence_parallel_config(cls, value, info): def check_sequence_parallel_degree(cls, value, info):
if not value: if not value:
value = 1 value = 1

0
tests/__init__.py Normal file
View File

View File

@@ -11,7 +11,11 @@ import time
import pytest import pytest
import requests import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1): def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -25,9 +29,11 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
except ( except (
requests.exceptions.ReadTimeout, requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError, requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
) as exc: ) as exc:
if attempt < max_retries - 1: if attempt < max_retries - 1:
time.sleep(delay) wait = 2**attempt * delay # in seconds
time.sleep(wait)
else: else:
raise exc raise exc
@@ -37,6 +43,7 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
@disable_hf_offline
def snapshot_download_w_retry(*args, **kwargs): def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs) return snapshot_download(*args, **kwargs)
@@ -44,19 +51,19 @@ def snapshot_download_w_retry(*args, **kwargs):
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model(): def download_smollm2_135m_model():
# download the model # download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M") snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model(): def download_llama_68m_random_model():
# download the model # download the model
snapshot_download_w_retry("JackFram/llama-68m") snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model(): def download_qwen_2_5_half_billion_model():
# download the model # download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B") snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@@ -101,6 +108,37 @@ def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
) )
@pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
)
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset",
revision="ea82cff",
)
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset(): def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset # download the dataset
@@ -109,10 +147,141 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
) )
@pytest.fixture(scope="session", autouse=True)
def download_argilla_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset(): def download_tiny_shakespeare_dataset():
# download the dataset # download the dataset
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset") snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_huggyllama_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"huggyllama/llama-7b",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-3.2-1B",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B-Instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_3_medium_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"casperhansen/mistral-7b-instruct-v0.1-awq",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma_2b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"unsloth/gemma-2b-it",
revision="703fb4a",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma2_9b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/gemma-2-9b-it-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mlx_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-2-7b-hf",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture @pytest.fixture
@@ -178,3 +347,34 @@ def cleanup_monkeypatches():
module_globals = module_name_tuple[1] module_globals = module_name_tuple[1]
for module_global in module_globals: for module_global in module_globals:
globals().pop(module_global, None) globals().pop(module_global, None)
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass

View File

@@ -10,10 +10,13 @@ from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats from axolotl.core.chat.messages import ChatFormattedChats, Chats
from tests.hf_offline_utils import enable_hf_offline # noqa
@pytest.fixture(scope="session", name="llama_tokenizer") @pytest.fixture(scope="session", name="llama_tokenizer")
@enable_hf_offline
def llama_tokenizer_fixture(): def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B") return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer") @pytest.fixture(scope="session", name="chatml_tokenizer")

View File

@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path from pathlib import Path
import pytest import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
@@ -13,6 +12,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg") @pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir): def min_cfg(temp_dir):

View File

@@ -2,15 +2,13 @@
Simple end-to-end test for Liger integration Simple end-to-end test for Liger integration
""" """
from e2e.utils import require_torch_2_4_1
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
class LigerIntegrationTestCase: class LigerIntegrationTestCase:

View File

@@ -8,11 +8,12 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO: class TestGRPO:
""" """

View File

@@ -9,12 +9,13 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -9,10 +9,11 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -110,7 +110,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4 # Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4) register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
# Verify the number of calls without examining the arguments # Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2 assert mock_new_group.call_count == 2

View File

@@ -14,6 +14,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -23,6 +25,7 @@ class TestDeepseekV3:
Test case for DeepseekV3 models Test case for DeepseekV3 models
""" """
@enable_hf_offline
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sample_packing", "sample_packing",
[True, False], [True, False],
@@ -80,6 +83,7 @@ class TestDeepseekV3:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@enable_hf_offline
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sample_packing", "sample_packing",
[True, False], [True, False],

View File

@@ -5,14 +5,14 @@ E2E tests for llama
import logging import logging
import os import os
from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

85
tests/hf_offline_utils.py Normal file
View File

@@ -0,0 +1,85 @@
"""
test utils for helpers and decorators
"""
import os
from functools import wraps
from huggingface_hub.utils import reset_sessions
def reload_modules(hf_hub_offline):
# Force reload of the modules that check this variable
import importlib
import datasets
import huggingface_hub.constants
# Reload the constants module first, as others depend on it
importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions()
def enable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "1"
reload_modules(True)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper
def disable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "0"
reload_modules(False)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper

View File

@@ -4,12 +4,13 @@ shared fixtures for prompt strategies tests
import pytest import pytest
from datasets import Dataset from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES from axolotl.utils.chat_templates import _CHAT_TEMPLATES
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset") @pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset(): def fixture_assistant_dataset():
@@ -108,31 +109,27 @@ def fixture_toolcalling_dataset():
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True) @pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
def fixture_llama3_tokenizer(): @enable_hf_offline
hf_hub_download( def fixture_llama3_tokenizer(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", download_llama3_8b_instruct_model_fixture,
filename="special_tokens_map.json", ): # pylint: disable=unused-argument,redefined-outer-name
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
filename="tokenizer_config.json",
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True) @pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_smollm2_tokenizer(): def fixture_smollm2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M") tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
return tokenizer return tokenizer
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True) @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer(): @enable_hf_offline
def fixture_mistralv03_tokenizer(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit" "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
) )
@@ -140,6 +137,7 @@ def fixture_mistralv03_tokenizer():
@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True) @pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_phi35_tokenizer(): def fixture_phi35_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct") tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
return tokenizer return tokenizer

View File

@@ -11,6 +11,8 @@ from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="alpaca_dataset") @pytest.fixture(name="alpaca_dataset")
def fixture_alpaca_dataset(): def fixture_alpaca_dataset():
@@ -26,6 +28,7 @@ def fixture_alpaca_dataset():
@pytest.fixture(name="tokenizer") @pytest.fixture(name="tokenizer")
@enable_hf_offline
def fixture_tokenizer(): def fixture_tokenizer():
# pylint: disable=all # pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -13,8 +13,11 @@ from axolotl.utils.chat_templates import (
get_chat_template, get_chat_template,
) )
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="llama3_tokenizer") @pytest.fixture(name="llama3_tokenizer")
@enable_hf_offline
def fixture_llama3_tokenizer(): def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

View File

@@ -17,6 +17,8 @@ from axolotl.prompt_strategies.chat_template import (
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.chat_templates import get_chat_template
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -30,12 +32,14 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja", "mistralv03_tokenizer_chat_template_jinja",
"[/INST]", "[/INST]",
), ),
( # TODO: temporarily skip gemma due to gemma3 template
"gemma2_tokenizer", # Re-enable on new chat_template implementation for perf
"jinja", # (
"gemma2_tokenizer_chat_template_jinja", # "gemma2_tokenizer",
"<end_of_turn>", # "jinja",
), # "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
("phi35_tokenizer", "phi_35", None, "<|end|>"), ("phi35_tokenizer", "phi_35", None, "<|end|>"),
] ]
@@ -93,7 +97,11 @@ class TestChatTemplateConfigurations:
if ( if (
turn_idx == 0 turn_idx == 0
and turn.get("from") in ["system", "context"] and turn.get("from") in ["system", "context"]
and "mistral" in tokenizer.name_or_path.lower() and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
): ):
assert ( assert (
start_idx == -1 and end_idx == -1 start_idx == -1 and end_idx == -1
@@ -101,6 +109,7 @@ class TestChatTemplateConfigurations:
return True return True
return False return False
@enable_hf_offline
def test_train_on_inputs_true( def test_train_on_inputs_true(
self, self,
tokenizer, tokenizer,

View File

@@ -11,6 +11,8 @@ from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset") @pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset(): def fixture_assistant_dataset():
@@ -78,15 +80,8 @@ def fixture_custom_assistant_dataset():
) )
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
@pytest.fixture(name="phi3_tokenizer") @pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer(): def fixture_phi3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct") tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
@@ -94,6 +89,7 @@ def fixture_phi3_tokenizer():
@pytest.fixture(name="gemma_tokenizer") @pytest.fixture(name="gemma_tokenizer")
@enable_hf_offline
def fixture_gemma_tokenizer(): def fixture_gemma_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a") tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")

View File

@@ -10,6 +10,8 @@ from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="minimal_dpo_cfg") @pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg(): def fixture_cfg():
@@ -34,6 +36,8 @@ class TestDPOChatml:
Test loading DPO preference datasets with chatml formatting Test loading DPO preference datasets with chatml formatting
""" """
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_default(self, minimal_dpo_cfg): def test_default(self, minimal_dpo_cfg):
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -8,12 +8,15 @@ from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5 from axolotl.utils.data import encode_pretraining, md5
from tests.hf_offline_utils import enable_hf_offline
class TestEncodePretraining(unittest.TestCase): class TestEncodePretraining(unittest.TestCase):
""" """
test class for encode pretraining and md5 helper test class for encode pretraining and md5 helper
""" """
@enable_hf_offline
def setUp(self): def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens( self.tokenizer.add_special_tokens(

View File

@@ -4,31 +4,37 @@ Test dataset loading under various conditions.
import shutil import shutil
import tempfile import tempfile
import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import patch
from conftest import snapshot_download_w_retry import pytest
from constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from tests.hf_offline_utils import enable_hf_offline
class TestDatasetPreparation(unittest.TestCase):
class TestDatasetPreparation:
"""Test a configured dataloader.""" """Test a configured dataloader."""
def setUp(self) -> None: @pytest.fixture
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
# Alpaca dataset. yield tokenizer_huggyllama
self.dataset = Dataset.from_list(
@pytest.fixture
def dataset_fixture(self):
yield Dataset.from_list(
[ [
{ {
"instruction": "Evaluate this sentence for spelling and grammar mistakes", "instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -38,7 +44,9 @@ class TestDatasetPreparation(unittest.TestCase):
] ]
) )
def test_load_hub(self): @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
"""Core use case. Verify that processing data from the hub works""" """Core use case. Verify that processing data from the hub works"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
@@ -55,25 +63,28 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_local_hub(self): @enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub(self, tokenizer):
"""Niche use case. Verify that a local copy of a hub dataset can be loaded""" """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download_w_retry( snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset", repo_type="dataset",
local_dir=tmp_ds_path, local_dir=tmp_ds_path,
) )
# offline mode doesn't actually copy it to local_dir, so we
# have to copy all the contents in the dir manually from the returned snapshot_path
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
# Right now a local copy that doesn't fully conform to a dataset # Right now a local copy that doesn't fully conform to a dataset
@@ -96,9 +107,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -106,11 +115,12 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
def test_load_from_save_to_disk(self): @enable_hf_offline
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset" tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name)) dataset_fixture.save_to_disk(str(tmp_ds_name))
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -126,22 +136,21 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_dir_of_parquet(self): @enable_hf_offline
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded.""" """Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.parquet" tmp_ds_path = tmp_ds_dir / "shard1.parquet"
self.dataset.to_parquet(tmp_ds_path) dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -162,22 +171,21 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_dir_of_json(self): @enable_hf_offline
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a directory of json files can be loaded.""" """Standard use case. Verify a directory of json files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.json" tmp_ds_path = tmp_ds_dir / "shard1.json"
self.dataset.to_json(tmp_ds_path) dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -198,20 +206,19 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_single_parquet(self): @enable_hf_offline
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single parquet file can be loaded.""" """Standard use case. Verify a single parquet file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet" tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
self.dataset.to_parquet(tmp_ds_path) dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -228,20 +235,19 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_single_json(self): @enable_hf_offline
def test_load_from_single_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single json file can be loaded.""" """Standard use case. Verify a single json file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json" tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
self.dataset.to_json(tmp_ds_path) dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -258,15 +264,15 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@enable_hf_offline
def test_load_hub_with_dpo(self): def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works""" """Verify that processing dpo data from the hub works"""
@@ -285,7 +291,9 @@ class TestDatasetPreparation(unittest.TestCase):
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features
def test_load_hub_with_revision(self): @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub_with_revision(self, tokenizer):
"""Verify that processing data from the hub works with a specific revision""" """Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
@@ -307,16 +315,17 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_hub_with_revision_with_dpo(self): @enable_hf_offline
def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision""" """Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault( cfg = DictDefault(
@@ -329,22 +338,34 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_preference_datasets(cfg) # pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
)
assert len(train_dataset) == 1800 train_dataset, _ = load_prepare_preference_datasets(cfg)
assert "conversation" in train_dataset.features
def test_load_local_hub_with_revision(self): assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self, tokenizer):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision""" """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download_w_retry( snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset", repo_type="dataset",
local_dir=tmp_ds_path, local_dir=tmp_ds_path,
revision="d05c1cb", revision="d05c1cb",
) )
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -365,9 +386,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -375,17 +394,19 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
def test_loading_local_dataset_folder(self): @enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer):
"""Verify that a dataset downloaded to a local folder can be loaded""" """Verify that a dataset downloaded to a local folder can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download_w_retry( snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset", repo_type="dataset",
local_dir=tmp_ds_path, local_dir=tmp_ds_path,
) )
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -401,16 +422,10 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -8,9 +8,8 @@ import hashlib
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS import pytest
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
@@ -19,6 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
from tests.hf_offline_utils import enable_hf_offline
def verify_deduplication(actual_dataset, expected_dataset, dataset_name): def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
""" """
@@ -214,13 +216,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset") verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset(unittest.TestCase): class TestDeduplicateRLDataset:
"""Test a configured dataloader with deduplication.""" """Test a configured dataloader with deduplication."""
def setUp(self) -> None: @pytest.fixture
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") def cfg(self):
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) fixture = DictDefault(
self.cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b", "tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024, "sequence_len": 1024,
@@ -233,34 +234,66 @@ class TestDeduplicateRLDataset(unittest.TestCase):
], ],
} }
) )
yield fixture
def test_load_with_deduplication(self): @enable_hf_offline
def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting # pylint: disable=duplicate-code
train_dataset, _ = load_prepare_preference_datasets(self.cfg) with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
# Verify that the dataset has been deduplicated train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
def test_load_without_deduplication(self): # Verify that the dataset has been deduplicated
"""Verify that loading without deduplication retains duplicates.""" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates @enable_hf_offline
assert ( def test_load_without_deduplication(
len(train_dataset) == 1800 * 2 self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
), "Dataset deduplication occurred when it should not have" ):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase): class TestDeduplicateNonRL(unittest.TestCase):
"""Test prepare_dataset function with different configurations.""" """Test prepare_dataset function with different configurations."""
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault( self.cfg_1 = DictDefault(
{ {
"base_model": "huggyllama/llama-7b", "base_model": "huggyllama/llama-7b",
@@ -286,6 +319,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
) )
normalize_config(self.cfg_1) normalize_config(self.cfg_1)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_with_deduplication_train(self): def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication.""" """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True self.cfg_1.dataset_exact_deduplication = True
@@ -311,6 +346,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
"Train dataset should have 2000 samples after deduplication.", "Train dataset should have 2000 samples after deduplication.",
) )
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_with_deduplication_eval(self): def test_prepare_dataset_with_deduplication_eval(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication.""" """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True self.cfg_1.dataset_exact_deduplication = True
@@ -336,6 +373,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
"Eval dataset should have 2000 samples after deduplication.", "Eval dataset should have 2000 samples after deduplication.",
) )
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_without_deduplication(self): def test_prepare_dataset_without_deduplication(self):
"""Verify that prepare_dataset function processes the dataset correctly without deduplication.""" """Verify that prepare_dataset function processes the dataset correctly without deduplication."""
self.cfg_1.dataset_exact_deduplication = False self.cfg_1.dataset_exact_deduplication = False

View File

@@ -12,6 +12,8 @@ from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="tokenizer") @pytest.fixture(name="tokenizer")
def fixture_tokenizer(): def fixture_tokenizer():
@@ -25,6 +27,7 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
""" """
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_size, num_workers", "batch_size, num_workers",
[ [
@@ -35,11 +38,12 @@ class TestBatchedSamplerPacking:
], ],
) )
@pytest.mark.parametrize("max_seq_length", [4096, 512]) @pytest.mark.parametrize("max_seq_length", [4096, 512])
@enable_hf_offline
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset( dataset = load_dataset(
"Trelis/tiny-shakespeare", "winglian/tiny-shakespeare",
split="train", split="train",
) )

View File

@@ -10,12 +10,15 @@ from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter from axolotl.prompters import AlpacaPrompter
from tests.hf_offline_utils import enable_hf_offline
class TestPacking(unittest.TestCase): class TestPacking(unittest.TestCase):
""" """
Test class for packing dataset sequences Test class for packing dataset sequences
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

View File

@@ -1,43 +1,60 @@
"""Module for testing streaming dataset sequence packing""" """Module for testing streaming dataset sequence packing"""
import functools import functools
import unittest import random
import string
import pytest import pytest
import torch import torch
from datasets import load_dataset from datasets import IterableDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
class TestPretrainingPacking(unittest.TestCase): class TestPretrainingPacking:
""" """
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
""" """
def setUp(self) -> None: @pytest.fixture
# pylint: disable=duplicate-code def random_text(self):
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") # seed with random.seed(0) for reproducibility
self.tokenizer.pad_token = "</s>" random.seed(0)
@pytest.mark.flaky(retries=3, delay=5) # generate row of random text with "words" of between 2 and 10 characters and
def test_packing_stream_dataset(self): # between 400 to 1200 characters per line
# pylint: disable=duplicate-code def rand_txt():
dataset = load_dataset( return " ".join(
"allenai/c4", [
"en", "".join(
streaming=True, random.choices(string.ascii_lowercase, k=random.randint(2, 10))
)["train"] )
for _ in range(random.randint(50, 200))
]
)
# Create a list of 2000 random texts rather than just using it within the
# generator so the test runs faster
data = [rand_txt() for _ in range(500)]
# Create an IterableDataset
def generator():
for row in data:
yield {"text": row}
return IterableDataset.from_generator(generator)
@pytest.mark.flaky(retries=1, delay=5)
def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
dataset = random_text
cfg = DictDefault( cfg = DictDefault(
{ {
"pretraining_dataset": [ "pretraining_dataset": [
{ {
"path": "allenai/c4", "path": "winglian/tiny-shakespeare",
"name": "en",
"type": "pretrain", "type": "pretrain",
} }
], ],
@@ -54,15 +71,16 @@ class TestPretrainingPacking(unittest.TestCase):
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
self.tokenizer, tokenizer_huggyllama,
cfg, cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
# pylint: disable=duplicate-code
original_bsz = cfg.micro_batch_size original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
dataset, dataset,
self.tokenizer, tokenizer_huggyllama,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
@@ -78,7 +96,7 @@ class TestPretrainingPacking(unittest.TestCase):
) )
idx = 0 idx = 0
for data in trainer_loader: for data in trainer_loader:
if idx > 10: if idx > 3:
break break
assert data["input_ids"].shape == torch.Size( assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len] [1, original_bsz * cfg.sequence_len]
@@ -95,7 +113,3 @@ class TestPretrainingPacking(unittest.TestCase):
# [1, original_bsz * cfg.sequence_len] # [1, original_bsz * cfg.sequence_len]
# ) # )
idx += 1 idx += 1
if __name__ == "__main__":
unittest.main()

View File

@@ -5,6 +5,7 @@ import logging
import unittest import unittest
from pathlib import Path from pathlib import Path
import pytest
from datasets import load_dataset from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -22,6 +23,8 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
test_data = { test_data = {
@@ -63,6 +66,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies. Test class for prompt tokenization strategies.
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -119,6 +123,7 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
Test class for prompt tokenization strategies with sys prompt from the dataset Test class for prompt tokenization strategies with sys prompt from the dataset
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -160,6 +165,7 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
Test class for prompt tokenization strategies with sys prompt from the dataset Test class for prompt tokenization strategies with sys prompt from the dataset
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf") self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
@@ -238,6 +244,7 @@ If a question does not make any sense, or is not factually coherent, explain why
class OrpoTokenizationTest(unittest.TestCase): class OrpoTokenizationTest(unittest.TestCase):
"""test case for the ORPO tokenization""" """test case for the ORPO tokenization"""
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
@@ -262,6 +269,7 @@ class OrpoTokenizationTest(unittest.TestCase):
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train" "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0]) ).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self): def test_orpo_integration(self):
strat = load( strat = load(
self.tokenizer, self.tokenizer,

View File

@@ -9,12 +9,15 @@ import pytest
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers: class TestTokenizers:
""" """
test class for the load_tokenizer fn test class for the load_tokenizer fn
""" """
@enable_hf_offline
def test_default_use_fast(self): def test_default_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -24,6 +27,7 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__ assert "Fast" in tokenizer.__class__.__name__
@enable_hf_offline
def test_dont_use_fast(self): def test_dont_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -34,6 +38,7 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__ assert "Fast" not in tokenizer.__class__.__name__
@enable_hf_offline
def test_special_tokens_modules_to_save(self): def test_special_tokens_modules_to_save(self):
# setting special_tokens to new token # setting special_tokens to new token
cfg = DictDefault( cfg = DictDefault(
@@ -68,6 +73,7 @@ class TestTokenizers:
) )
load_tokenizer(cfg) load_tokenizer(cfg)
@enable_hf_offline
def test_add_additional_special_tokens(self): def test_add_additional_special_tokens(self):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -83,6 +89,7 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001 assert len(tokenizer) == 32001
@enable_hf_offline
def test_added_tokens_overrides(self, temp_dir): def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -104,11 +111,12 @@ class TestTokenizers:
128042 128042
] ]
@enable_hf_offline
def test_added_tokens_overrides_with_toolargeid(self, temp_dir): def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {
# use with tokenizer that has reserved_tokens in added_tokens # use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B", "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"}, "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir, "output_dir": temp_dir,
} }

0
tests/utils/__init__.py Normal file
View File