Compare commits

..

7 Commits

Author SHA1 Message Date
NanoCode012
7888a35118 chore: remove unused log 2025-03-31 16:20:15 +07:00
NanoCode012
873385b7d5 feat: update xformers for new attention interface 2025-03-31 16:15:55 +07:00
NanoCode012
cf0c79d52e fix: minor patches for multimodal (#2441)
* fix: update chat_template

* fix: handle gemma3 showing a lot of no content for turn 0

* fix: remove unknown config from examples

* fix: test

* fix: temporary disable gemma2 test

* fix: stop overwriting config.text_config unnecessarily

* fix: handling of set cache to the text_config section

* feat: add liger gemma support and bump liger to 0.5.5

* fix: add double use_cache setting

* fix: add support for final_logit_softcap in CCE for gemma2/3

* fix: set use_cache before model load

* feat: add missing layernorm override

* fix: handle gemma3 rmsnorm

* fix: use wrapper to pass dim as hidden_size

* fix: change dim to positional

* fix: patch with wrong mlp

* chore: refactor use_cache handling

* fix import issues

* fix tests.e2e.utils import

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-31 13:40:12 +07:00
Wing Lian
4ba80a0e5a fix streaming packing test (#2454)
* fix streaming packing test

* constrain amount of text generated
2025-03-29 08:30:06 -04:00
Wing Lian
c49682132b use offline for precached stream dataset (#2453) 2025-03-28 23:39:09 -04:00
Wing Lian
e46239f8d3 bump liger to 0.5.5 (#2448) 2025-03-28 19:21:03 -04:00
Wing Lian
05f03b541a hf offline decorator for tests to workaround rate limits (#2452) [skip ci]
* hf offline decorator for tests to workaround rate limits

* fail quicker so we can see logs

* try new cache name

* limit files downloaded

* phi mini predownload

* offline decorator for phi tokenizer

* handle meta llama 8b offline too

* make sure to return fixtures if they are wrapped too

* more fixes

* more things offline

* more offline things

* fix the env var

* fix the model name

* handle gemma also

* force reload of modules to recheck offline status

* prefetch mistral too

* use reset_sessions so hub picks up offline mode

* more fixes

* rename so it doesn't seem like a context manager

* fix backoff

* switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline

* include additional dataset

* more fixes

* more fixes

* replace tiny shakespeaere dataset

* skip some tests for now

* use more robust check using snapshot download to determine if a dataset name is on the hub

* typo for skip reason

* use local_files_only

* more fixtures

* remove local only

* use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached

* make sure fixtures aren't offline

improve the offline reset
try bumping version of datasets
reorder reloading and setting
prime a new cache
run the tests now with fresh cache
try with a static cache

* now run all the ci again with hopefully a correct cache

* skip wonky tests for now

* skip wonky tests for now

* handle offline mode for model card creation
2025-03-28 19:20:46 -04:00
44 changed files with 876 additions and 462 deletions

View File

@@ -136,4 +136,4 @@ jobs:
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.tests modal run cicd.e2e_tests

View File

@@ -63,7 +63,7 @@ jobs:
path: | path: |
/home/runner/.cache/huggingface/hub/datasets--* /home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--* /home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }} key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
@@ -137,7 +137,7 @@ jobs:
path: | path: |
/home/runner/.cache/huggingface/hub/datasets--* /home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--* /home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }} key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
@@ -171,6 +171,9 @@ jobs:
run: | run: |
axolotl --help axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -229,7 +232,7 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.tests modal run cicd.e2e_tests
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
@@ -276,4 +279,4 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.tests modal run cicd.e2e_tests

View File

@@ -1,3 +1,4 @@
[settings] [settings]
profile=black profile=black
known_third_party=wandb,comet_ml known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,7 +10,7 @@ load_in_4bit: true
strict: false strict: false
# huggingface repo # huggingface repo
chat_template: gemma3_text chat_template: gemma3
datasets: datasets:
- path: cgato/SlimOrcaDedupCleaned - path: cgato/SlimOrcaDedupCleaned
type: chat_template type: chat_template

View File

@@ -19,7 +19,6 @@ val_set_size: 0.0
output_dir: ./outputs/lora-out output_dir: ./outputs/lora-out
dataset_exact_deduplication: true dataset_exact_deduplication: true
test_value: true
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.3 liger-kernel==0.5.5
# END section # END section
packaging==23.2 packaging==23.2
@@ -15,7 +15,7 @@ peft==0.15.0
transformers==4.50.0 transformers==4.50.0
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.5.2 accelerate==1.5.2
datasets==3.4.1 datasets==3.5.0
deepspeed==0.16.4 deepspeed==0.16.4
trl==0.15.1 trl==0.15.1

View File

@@ -25,8 +25,8 @@ import torch
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy") LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")

View File

@@ -15,7 +15,6 @@ import transformers
from cut_cross_entropy.transformers.utils import ( from cut_cross_entropy.transformers.utils import (
PatchOptions, PatchOptions,
TransformersModelT, TransformersModelT,
apply_lce,
) )
from torch import nn from torch import nn
from transformers.cache_utils import Cache, HybridCache from transformers.cache_utils import Cache, HybridCache
@@ -33,6 +32,8 @@ from transformers.utils import (
) )
from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.deprecation import deprecate_kwarg
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
_PATCH_OPTS: PatchOptions | None = None _PATCH_OPTS: PatchOptions | None = None
@@ -134,25 +135,17 @@ def cce_forward(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None assert labels is not None
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
)
loss = apply_lce( loss = apply_lce(
hidden_states[:, slice_indices, :], hidden_states[:, slice_indices, :],
self.lm_head.weight, self.lm_head.weight,
labels, labels,
_PATCH_OPTS, _PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**loss_kwargs, **loss_kwargs,
) )
elif _PATCH_OPTS is not None and defer_logits_calculation: elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward # defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :] logits = hidden_states[:, slice_indices, :]
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
)
else: else:
logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = self.lm_head(hidden_states[:, slice_indices, :])
if self.config.final_logit_softcapping is not None: if self.config.final_logit_softcapping is not None:
@@ -353,6 +346,7 @@ def cce_forward_multimodal(
self.language_model.lm_head.weight, self.language_model.lm_head.weight,
labels, labels,
_PATCH_OPTS, _PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**lm_kwargs, **lm_kwargs,
) )
else: else:

View File

@@ -0,0 +1,40 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Monkeypatch for apply_lce to add softcap."""
import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions
def apply_lce(
e: torch.Tensor,
c: torch.Tensor,
labels: torch.Tensor,
opts: PatchOptions,
bias: torch.Tensor | None = None,
softcap: float | None = None,
**loss_kwargs,
) -> torch.Tensor:
"""Monkey patch for apply_lce to support softcap kwarg."""
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
cce_kwargs = opts.to_kwargs()
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
cce_kwargs["reduction"] = "sum"
else:
num_items_in_batch = None
loss = linear_cross_entropy(
e,
c,
labels.to(e.device),
bias=bias,
shift=True,
softcap=softcap,
**cce_kwargs,
)
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return loss

View File

@@ -20,6 +20,26 @@ liger_layer_norm: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
``` ```
## Supported Models
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
- mistral
- mixtral
- mllama
- mllama_text_model
- olmo2
- paligemma
- phi3
- qwen2
- qwen2_5_vl
- qwen2_vl
## Citation ## Citation
```bib ```bib

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect import inspect
import logging import logging
import sys import sys
from functools import partial
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn) liger_fn_sig = inspect.signature(apply_liger_fn)
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
modeling_jamba.JambaRMSNorm = LigerRMSNorm modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation: if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn from transformers.loss.loss_utils import nn
@@ -104,15 +114,51 @@ class LigerPlugin(BasePlugin):
# The DeepseekV2 version of RoPE is different than upstream LLaMA. # The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.") logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation: if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method. # nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]: elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}") raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -1,153 +1,113 @@
""" """
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments Hijack the LlamaAttention forward method to use xformers if available.
Updated for transformers v4.50.0.
""" """
import logging from typing import Optional
import warnings
from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F from torch import nn
import transformers.models.llama.modeling_llama from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try: try:
import xformers.ops import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError: except ImportError:
logging.error("xformers not found! Please install it before trying to use it.") XFORMERS_AVAILABLE = False
def xformers_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs, # pylint: disable=unused-argument
):
"""
Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
Args:
module: The LlamaAttention module
query: Query states of shape [batch, num_heads, seq_len, head_dim]
key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
attention_mask: Attention mask
scaling: Scaling factor for attention scores
dropout: Dropout probability
Returns:
attn_output: Output of xformers memory-efficient attention
attn_weights: None
"""
# First, handle grouped-query attention (GQA)
# We need to repeat key and value states to match the number of query heads
num_key_value_groups = getattr(module, "num_key_value_groups", 1)
key = repeat_kv(key, num_key_value_groups)
value = repeat_kv(value, num_key_value_groups)
# xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Determine if we need a causal mask
is_causal = getattr(module, "is_causal", True)
# Set up the attention bias for xformers
if is_causal:
# Use xformers built-in causal mask
attn_bias = xformers.ops.LowerTriangularMask()
elif attention_mask is not None:
# For non-causal attention with a mask, we'd need to convert the mask
# This is a simplification - you might need to adapt based on your mask format
attn_bias = attention_mask
else:
# No mask needed
attn_bias = None
# Apply xformers memory-efficient attention
attn_output = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attn_bias,
p=dropout if module.training else 0.0,
scale=scaling,
)
# Reshape back to [batch, seq_len, hidden_size]
attn_output = attn_output.transpose(1, 2)
return attn_output, None # Return None for attn_weights to match interface
def hijack_llama_attention(): def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward """
Patch the LlamaAttention forward method to use xformers if available.
"""
def xformers_forward( if not XFORMERS_AVAILABLE:
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# xformers-attn start
#
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError( raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" "xformers not available. Please install it following axolotl's requirements."
f" {attn_output.size()}"
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# import transformers.models.llama.modeling_llama as llama_modeling
# xformers-attn end
#
if self.pretraining_tp > 1: # Add xformers to the available attention implementations
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value # Create a wrapper for the original LlamaAttention forward method
original_forward = llama_modeling.LlamaAttention.forward
def patched_forward(self, *args, **kwargs):
# Set the attention implementation to xformers
# pylint: disable=protected-access
self.config._attn_implementation = "xformers"
return original_forward(self, *args, **kwargs)
# Apply the patch
llama_modeling.LlamaAttention.forward = patched_forward

View File

@@ -1,6 +1,5 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" """Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
import ast
from copy import deepcopy from copy import deepcopy
from typing import Optional from typing import Optional
@@ -76,49 +75,6 @@ class ProcessingStrategy:
result["messages"] = messages result["messages"] = messages
return result return result
def convert_multiple_choice_to_multimedia_messages(
messages: dict,
) -> list[dict]:
def construct_prompt(sample):
question = sample["question"]
options = sample["options"]
if isinstance(options, str):
options = ast.literal_eval(options)
example = ""
start_chr = "A"
prediction_range = []
index2ans = {}
for option in options:
prediction_range.append(start_chr)
example += f"({start_chr}) {option}\n"
index2ans[start_chr] = option
start_chr = chr(ord(start_chr) + 1)
empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
empty_prompt = empty_prompt_sample_structure.format(question, example)
return empty_prompt
new_messages = []
user_content = construct_prompt(messages)
assistant_response = messages["answer"]
new_messages.append(
{"role": "user", "content": [{"type": "text", "text": user_content}]}
)
new_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": assistant_response}],
}
)
return new_messages
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
"""Convert regular messages format to Messages format with content type""" """Convert regular messages format to Messages format with content type"""
@@ -150,51 +106,39 @@ class ProcessingStrategy:
processed_examples = [] processed_examples = []
for example in examples: for example in examples:
if not ( if not ("messages" in example or "conversations" in example):
"messages" in example
or "conversations" in example
or "question" in example
):
raise ValueError( raise ValueError(
"Only `messages`, `conversations`, and `question` message keys are currently supported." "Only `messages` and `conversations` message keys are currently supported."
) )
processed_example = None processed_example = None
if "messages" in example: # OpenAI format if "messages" in example: # OpenAI format
processed_example = example processed_example = example
# convert regular messages format to Messages format with content type
# for compatibility with apply_chat_template
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
elif "question" in example: # Multiple choice format
processed_example = {}
processed_example["messages"] = (
convert_multiple_choice_to_multimedia_messages(example)
)
else: # Legacy format else: # Legacy format
processed_example = convert_legacy_format(example) processed_example = convert_legacy_format(example)
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"] # convert regular messages format to Messages format with content type
) # for compatibility with apply_chat_template
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
# find the image key if it exists # find the image key if it exists
possible_image_keys = ["images", "image"]
image_key = None
for key in possible_image_keys:
if key in processed_example:
image_key = key
break
image_keys = [] # if the image key exists, add the image to the first message
for key in example.keys(): if image_key is not None:
if "image" in key: # TODO: check if it's normal to be single image only for common datasets
image_keys.append(key) # From observation, it's usually a list of single image but some datasets may have several columns for images
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
for im_key in image_keys: image_value = processed_example[image_key][0]
if example[im_key] is None:
continue
if isinstance(example[im_key], list):
if len(example[im_key]) == 0:
continue
image_value = example[im_key][0]
else:
image_value = example[im_key]
# Handle image loading (Image, url, path, base64)
image_value = load_image(image_value) image_value = load_image(image_value)
if self.image_size is not None: if self.image_size is not None:
@@ -219,12 +163,33 @@ class ProcessingStrategy:
color=padding_color, color=padding_color,
) )
processed_example["messages"][0]["content"].append( # Look for any image type in the first message
{ # some dataset have an {type: "image"} in the first message
"type": "image", ind_to_add = None
"image": image_value,
} for i, content in enumerate(
) processed_example["messages"][0]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
ind_to_add = i
break
# If an image type is found, add the image to that index
if ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][
"image"
] = image_value
else:
# if no image type is found, add it to end of the first message
processed_example["messages"][0]["content"].append(
{
"type": "image",
"image": image_value,
}
)
processed_examples.append(processed_example) processed_examples.append(processed_example)

View File

@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if turn_idx >= len(turns): if turn_idx >= len(turns):
raise ValueError(f"Turn index {turn_idx} out of range") raise ValueError(f"Turn index {turn_idx} out of range")
# mistral does not output message if it contains only system message # mistral/gemma3 does not output message if it contains only system message
if ( if (
turn_idx == 0 turn_idx == 0
and turns[0].get("role") == "system" and turns[0].get("role") == "system"
and "mistral" in self.tokenizer.name_or_path.lower() and (
"mistral" in self.tokenizer.name_or_path.lower()
# gemma3 uses gemma tokenizer
or "gemma" in self.tokenizer.name_or_path.lower()
)
): ):
return -1, -1 return -1, -1

View File

@@ -14,6 +14,7 @@ import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -302,7 +303,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
model_card_kwarg["dataset_tags"] = dataset_tags model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg) trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError): except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
pass pass
elif cfg.hub_model_id: elif cfg.hub_model_id:
# Defensively push to the hub to ensure the model card is updated # Defensively push to the hub to ensure the model card is updated

View File

@@ -6,8 +6,12 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import HFValidationError from huggingface_hub.errors import (
HFValidationError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -70,20 +74,25 @@ def load_dataset_w_config(
# pylint: disable=invalid-name # pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try: try:
# this is just a basic check to see if the path is a # this is just a basic check to see if the path is a
# valid HF dataset that's loadable # valid HF dataset that's loadable
load_dataset( snapshot_download(
config_dataset.path, repo_id=config_dataset.path,
name=config_dataset.name, repo_type="dataset",
streaming=True,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision, revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code, ignore_patterns=["*"],
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): except (
RepositoryNotFoundError,
RevisionNotFoundError,
FileNotFoundError,
ConnectionError,
HFValidationError,
ValueError,
):
pass pass
ds_from_cloud = False ds_from_cloud = False

View File

@@ -8,7 +8,7 @@ import math
import os import os
import types import types
from functools import cached_property from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple
import addict import addict
import bitsandbytes as bnb import bitsandbytes as bnb
@@ -25,7 +25,7 @@ from peft import (
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from torch import nn from torch import nn
from transformers import ( # noqa: F401 from transformers import (
AddedToken, AddedToken,
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -39,6 +39,7 @@ from transformers import ( # noqa: F401
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration, MllamaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
ProcessorMixin, ProcessorMixin,
@@ -107,14 +108,21 @@ def get_module_class_from_name(module, name):
return None return None
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
# Set use_cache to False
if hasattr(model_config, "use_cache"):
model_config.use_cache = False
if cfg.is_multimodal: if cfg.is_multimodal:
if hasattr(model_config, "text_config"): # For multimodal configs, use_cache is set in the text_config
model_config = model_config.text_config if hasattr(model_config, "get_text_config"):
model_config.use_cache = False text_config = model_config.get_text_config()
elif hasattr(model_config, "get_text_config"): if hasattr(text_config, "use_cache"):
model_config = model_config.get_text_config() text_config.use_cache = False
model_config.use_cache = False else:
raise ValueError(
"No text config found for multimodal model. Please raise an Issue with model details."
)
# check if image_size is not set and load image size from model config if available # check if image_size is not set and load image size from model config if available
if ( if (
@@ -523,14 +531,6 @@ class ModelLoader:
# init model config # init model config
self.model_config = load_model_config(cfg) self.model_config = load_model_config(cfg)
if cfg.is_multimodal:
if hasattr(self.model_config, "text_config"):
self.text_model_config = self.model_config.text_config
else:
# for qwen2_vl
self.text_model_config = self.model_config.get_text_config()
else:
self.text_model_config = self.model_config
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
@@ -947,8 +947,6 @@ class ModelLoader:
quantization_config = ( quantization_config = (
quantization_config or self.model_kwargs["quantization_config"] quantization_config or self.model_kwargs["quantization_config"]
) )
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = load_sharded_model_quant( self.model = load_sharded_model_quant(
self.base_model, self.base_model,
self.model_config, self.model_config,
@@ -969,9 +967,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading() _ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
# Load model with random initialization if specified # Load model with random initialization if specified
if self.cfg.random_init_weights: if self.cfg.random_init_weights:
# AutoModel classes support the from_config method # AutoModel classes support the from_config method
@@ -1026,8 +1021,6 @@ class ModelLoader:
and self.model_type != "AutoModelForCausalLM" and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
): ):
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq: if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
@@ -1043,25 +1036,7 @@ class ModelLoader:
**self.model_kwargs, **self.model_kwargs,
) )
else: else:
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(self.text_model_config, "max_seq_len")
and self.text_model_config.max_seq_len
and self.cfg.sequence_len > self.text_model_config.max_seq_len
):
self.text_model_config.max_seq_len = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
elif (
hasattr(self.text_model_config, "max_sequence_length")
and self.text_model_config.max_sequence_length
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
):
self.text_model_config.max_sequence_length = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
if self.cfg.gptq: if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
@@ -1080,8 +1055,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading() _ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
@@ -1346,8 +1319,6 @@ class ModelLoader:
requires_grad.append(f"{name}: {param.requires_grad}") requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0: if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates") LOG.warning("there are no parameters that require gradient updates")
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if self.cfg.flash_optimum: if self.cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer

0
tests/__init__.py Normal file
View File

View File

@@ -11,7 +11,11 @@ import time
import pytest import pytest
import requests import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1): def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -25,9 +29,11 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
except ( except (
requests.exceptions.ReadTimeout, requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError, requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
) as exc: ) as exc:
if attempt < max_retries - 1: if attempt < max_retries - 1:
time.sleep(delay) wait = 2**attempt * delay # in seconds
time.sleep(wait)
else: else:
raise exc raise exc
@@ -37,6 +43,7 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
@disable_hf_offline
def snapshot_download_w_retry(*args, **kwargs): def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs) return snapshot_download(*args, **kwargs)
@@ -44,19 +51,19 @@ def snapshot_download_w_retry(*args, **kwargs):
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model(): def download_smollm2_135m_model():
# download the model # download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M") snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model(): def download_llama_68m_random_model():
# download the model # download the model
snapshot_download_w_retry("JackFram/llama-68m") snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model(): def download_qwen_2_5_half_billion_model():
# download the model # download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B") snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@@ -101,6 +108,37 @@ def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
) )
@pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
)
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset",
revision="ea82cff",
)
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset(): def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset # download the dataset
@@ -109,10 +147,141 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
) )
@pytest.fixture(scope="session", autouse=True)
def download_argilla_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset(): def download_tiny_shakespeare_dataset():
# download the dataset # download the dataset
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset") snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_huggyllama_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"huggyllama/llama-7b",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-3.2-1B",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B-Instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_3_medium_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"casperhansen/mistral-7b-instruct-v0.1-awq",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma_2b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"unsloth/gemma-2b-it",
revision="703fb4a",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma2_9b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/gemma-2-9b-it-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mlx_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-2-7b-hf",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture @pytest.fixture
@@ -178,3 +347,34 @@ def cleanup_monkeypatches():
module_globals = module_name_tuple[1] module_globals = module_name_tuple[1]
for module_global in module_globals: for module_global in module_globals:
globals().pop(module_global, None) globals().pop(module_global, None)
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass

View File

@@ -10,10 +10,13 @@ from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats from axolotl.core.chat.messages import ChatFormattedChats, Chats
from tests.hf_offline_utils import enable_hf_offline # noqa
@pytest.fixture(scope="session", name="llama_tokenizer") @pytest.fixture(scope="session", name="llama_tokenizer")
@enable_hf_offline
def llama_tokenizer_fixture(): def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B") return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer") @pytest.fixture(scope="session", name="chatml_tokenizer")

View File

@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path from pathlib import Path
import pytest import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
@@ -13,6 +12,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg") @pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir): def min_cfg(temp_dir):

View File

@@ -2,15 +2,13 @@
Simple end-to-end test for Liger integration Simple end-to-end test for Liger integration
""" """
from e2e.utils import require_torch_2_4_1
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
class LigerIntegrationTestCase: class LigerIntegrationTestCase:

View File

@@ -8,11 +8,12 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO: class TestGRPO:
""" """

View File

@@ -9,12 +9,13 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -9,10 +9,11 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -14,6 +14,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -23,6 +25,7 @@ class TestDeepseekV3:
Test case for DeepseekV3 models Test case for DeepseekV3 models
""" """
@enable_hf_offline
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sample_packing", "sample_packing",
[True, False], [True, False],
@@ -80,6 +83,7 @@ class TestDeepseekV3:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@enable_hf_offline
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sample_packing", "sample_packing",
[True, False], [True, False],

View File

@@ -5,14 +5,14 @@ E2E tests for llama
import logging import logging
import os import os
from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

85
tests/hf_offline_utils.py Normal file
View File

@@ -0,0 +1,85 @@
"""
test utils for helpers and decorators
"""
import os
from functools import wraps
from huggingface_hub.utils import reset_sessions
def reload_modules(hf_hub_offline):
# Force reload of the modules that check this variable
import importlib
import datasets
import huggingface_hub.constants
# Reload the constants module first, as others depend on it
importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions()
def enable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "1"
reload_modules(True)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper
def disable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "0"
reload_modules(False)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper

View File

@@ -4,12 +4,13 @@ shared fixtures for prompt strategies tests
import pytest import pytest
from datasets import Dataset from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES from axolotl.utils.chat_templates import _CHAT_TEMPLATES
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset") @pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset(): def fixture_assistant_dataset():
@@ -108,31 +109,27 @@ def fixture_toolcalling_dataset():
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True) @pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
def fixture_llama3_tokenizer(): @enable_hf_offline
hf_hub_download( def fixture_llama3_tokenizer(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", download_llama3_8b_instruct_model_fixture,
filename="special_tokens_map.json", ): # pylint: disable=unused-argument,redefined-outer-name
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
filename="tokenizer_config.json",
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True) @pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_smollm2_tokenizer(): def fixture_smollm2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M") tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
return tokenizer return tokenizer
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True) @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer(): @enable_hf_offline
def fixture_mistralv03_tokenizer(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit" "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
) )
@@ -140,6 +137,7 @@ def fixture_mistralv03_tokenizer():
@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True) @pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_phi35_tokenizer(): def fixture_phi35_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct") tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
return tokenizer return tokenizer

View File

@@ -11,6 +11,8 @@ from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="alpaca_dataset") @pytest.fixture(name="alpaca_dataset")
def fixture_alpaca_dataset(): def fixture_alpaca_dataset():
@@ -26,6 +28,7 @@ def fixture_alpaca_dataset():
@pytest.fixture(name="tokenizer") @pytest.fixture(name="tokenizer")
@enable_hf_offline
def fixture_tokenizer(): def fixture_tokenizer():
# pylint: disable=all # pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -13,8 +13,11 @@ from axolotl.utils.chat_templates import (
get_chat_template, get_chat_template,
) )
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="llama3_tokenizer") @pytest.fixture(name="llama3_tokenizer")
@enable_hf_offline
def fixture_llama3_tokenizer(): def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

View File

@@ -17,6 +17,8 @@ from axolotl.prompt_strategies.chat_template import (
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.chat_templates import get_chat_template
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -30,12 +32,14 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja", "mistralv03_tokenizer_chat_template_jinja",
"[/INST]", "[/INST]",
), ),
( # TODO: temporarily skip gemma due to gemma3 template
"gemma2_tokenizer", # Re-enable on new chat_template implementation for perf
"jinja", # (
"gemma2_tokenizer_chat_template_jinja", # "gemma2_tokenizer",
"<end_of_turn>", # "jinja",
), # "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
("phi35_tokenizer", "phi_35", None, "<|end|>"), ("phi35_tokenizer", "phi_35", None, "<|end|>"),
] ]
@@ -93,7 +97,11 @@ class TestChatTemplateConfigurations:
if ( if (
turn_idx == 0 turn_idx == 0
and turn.get("from") in ["system", "context"] and turn.get("from") in ["system", "context"]
and "mistral" in tokenizer.name_or_path.lower() and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
): ):
assert ( assert (
start_idx == -1 and end_idx == -1 start_idx == -1 and end_idx == -1
@@ -101,6 +109,7 @@ class TestChatTemplateConfigurations:
return True return True
return False return False
@enable_hf_offline
def test_train_on_inputs_true( def test_train_on_inputs_true(
self, self,
tokenizer, tokenizer,

View File

@@ -11,6 +11,8 @@ from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset") @pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset(): def fixture_assistant_dataset():
@@ -78,15 +80,8 @@ def fixture_custom_assistant_dataset():
) )
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
@pytest.fixture(name="phi3_tokenizer") @pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer(): def fixture_phi3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct") tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
@@ -94,6 +89,7 @@ def fixture_phi3_tokenizer():
@pytest.fixture(name="gemma_tokenizer") @pytest.fixture(name="gemma_tokenizer")
@enable_hf_offline
def fixture_gemma_tokenizer(): def fixture_gemma_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a") tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")

View File

@@ -10,6 +10,8 @@ from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="minimal_dpo_cfg") @pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg(): def fixture_cfg():
@@ -34,6 +36,8 @@ class TestDPOChatml:
Test loading DPO preference datasets with chatml formatting Test loading DPO preference datasets with chatml formatting
""" """
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_default(self, minimal_dpo_cfg): def test_default(self, minimal_dpo_cfg):
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -8,12 +8,15 @@ from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5 from axolotl.utils.data import encode_pretraining, md5
from tests.hf_offline_utils import enable_hf_offline
class TestEncodePretraining(unittest.TestCase): class TestEncodePretraining(unittest.TestCase):
""" """
test class for encode pretraining and md5 helper test class for encode pretraining and md5 helper
""" """
@enable_hf_offline
def setUp(self): def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens( self.tokenizer.add_special_tokens(

View File

@@ -4,31 +4,37 @@ Test dataset loading under various conditions.
import shutil import shutil
import tempfile import tempfile
import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import patch
from conftest import snapshot_download_w_retry import pytest
from constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from tests.hf_offline_utils import enable_hf_offline
class TestDatasetPreparation(unittest.TestCase):
class TestDatasetPreparation:
"""Test a configured dataloader.""" """Test a configured dataloader."""
def setUp(self) -> None: @pytest.fixture
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
# Alpaca dataset. yield tokenizer_huggyllama
self.dataset = Dataset.from_list(
@pytest.fixture
def dataset_fixture(self):
yield Dataset.from_list(
[ [
{ {
"instruction": "Evaluate this sentence for spelling and grammar mistakes", "instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -38,7 +44,9 @@ class TestDatasetPreparation(unittest.TestCase):
] ]
) )
def test_load_hub(self): @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
"""Core use case. Verify that processing data from the hub works""" """Core use case. Verify that processing data from the hub works"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
@@ -55,25 +63,28 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_local_hub(self): @enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub(self, tokenizer):
"""Niche use case. Verify that a local copy of a hub dataset can be loaded""" """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download_w_retry( snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset", repo_type="dataset",
local_dir=tmp_ds_path, local_dir=tmp_ds_path,
) )
# offline mode doesn't actually copy it to local_dir, so we
# have to copy all the contents in the dir manually from the returned snapshot_path
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
# Right now a local copy that doesn't fully conform to a dataset # Right now a local copy that doesn't fully conform to a dataset
@@ -96,9 +107,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -106,11 +115,12 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
def test_load_from_save_to_disk(self): @enable_hf_offline
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset" tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name)) dataset_fixture.save_to_disk(str(tmp_ds_name))
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -126,22 +136,21 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_dir_of_parquet(self): @enable_hf_offline
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded.""" """Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.parquet" tmp_ds_path = tmp_ds_dir / "shard1.parquet"
self.dataset.to_parquet(tmp_ds_path) dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -162,22 +171,21 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_dir_of_json(self): @enable_hf_offline
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a directory of json files can be loaded.""" """Standard use case. Verify a directory of json files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.json" tmp_ds_path = tmp_ds_dir / "shard1.json"
self.dataset.to_json(tmp_ds_path) dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -198,20 +206,19 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_single_parquet(self): @enable_hf_offline
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single parquet file can be loaded.""" """Standard use case. Verify a single parquet file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet" tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
self.dataset.to_parquet(tmp_ds_path) dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -228,20 +235,19 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_from_single_json(self): @enable_hf_offline
def test_load_from_single_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single json file can be loaded.""" """Standard use case. Verify a single json file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json" tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
self.dataset.to_json(tmp_ds_path) dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -258,15 +264,15 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@enable_hf_offline
def test_load_hub_with_dpo(self): def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works""" """Verify that processing dpo data from the hub works"""
@@ -285,7 +291,9 @@ class TestDatasetPreparation(unittest.TestCase):
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features
def test_load_hub_with_revision(self): @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub_with_revision(self, tokenizer):
"""Verify that processing data from the hub works with a specific revision""" """Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
@@ -307,16 +315,17 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_hub_with_revision_with_dpo(self): @enable_hf_offline
def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision""" """Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault( cfg = DictDefault(
@@ -329,22 +338,34 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_preference_datasets(cfg) # pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
)
assert len(train_dataset) == 1800 train_dataset, _ = load_prepare_preference_datasets(cfg)
assert "conversation" in train_dataset.features
def test_load_local_hub_with_revision(self): assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self, tokenizer):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision""" """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download_w_retry( snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset", repo_type="dataset",
local_dir=tmp_ds_path, local_dir=tmp_ds_path,
revision="d05c1cb", revision="d05c1cb",
) )
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -365,9 +386,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -375,17 +394,19 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
def test_loading_local_dataset_folder(self): @enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer):
"""Verify that a dataset downloaded to a local folder can be loaded""" """Verify that a dataset downloaded to a local folder can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download_w_retry( snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset", repo_type="dataset",
local_dir=tmp_ds_path, local_dir=tmp_ds_path,
) )
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -401,16 +422,10 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -8,9 +8,8 @@ import hashlib
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS import pytest
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
@@ -19,6 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
from tests.hf_offline_utils import enable_hf_offline
def verify_deduplication(actual_dataset, expected_dataset, dataset_name): def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
""" """
@@ -214,13 +216,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset") verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset(unittest.TestCase): class TestDeduplicateRLDataset:
"""Test a configured dataloader with deduplication.""" """Test a configured dataloader with deduplication."""
def setUp(self) -> None: @pytest.fixture
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") def cfg(self):
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) fixture = DictDefault(
self.cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b", "tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024, "sequence_len": 1024,
@@ -233,34 +234,66 @@ class TestDeduplicateRLDataset(unittest.TestCase):
], ],
} }
) )
yield fixture
def test_load_with_deduplication(self): @enable_hf_offline
def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting # pylint: disable=duplicate-code
train_dataset, _ = load_prepare_preference_datasets(self.cfg) with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
# Verify that the dataset has been deduplicated train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
def test_load_without_deduplication(self): # Verify that the dataset has been deduplicated
"""Verify that loading without deduplication retains duplicates.""" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates @enable_hf_offline
assert ( def test_load_without_deduplication(
len(train_dataset) == 1800 * 2 self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
), "Dataset deduplication occurred when it should not have" ):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase): class TestDeduplicateNonRL(unittest.TestCase):
"""Test prepare_dataset function with different configurations.""" """Test prepare_dataset function with different configurations."""
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault( self.cfg_1 = DictDefault(
{ {
"base_model": "huggyllama/llama-7b", "base_model": "huggyllama/llama-7b",
@@ -286,6 +319,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
) )
normalize_config(self.cfg_1) normalize_config(self.cfg_1)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_with_deduplication_train(self): def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication.""" """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True self.cfg_1.dataset_exact_deduplication = True
@@ -311,6 +346,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
"Train dataset should have 2000 samples after deduplication.", "Train dataset should have 2000 samples after deduplication.",
) )
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_with_deduplication_eval(self): def test_prepare_dataset_with_deduplication_eval(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication.""" """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True self.cfg_1.dataset_exact_deduplication = True
@@ -336,6 +373,8 @@ class TestDeduplicateNonRL(unittest.TestCase):
"Eval dataset should have 2000 samples after deduplication.", "Eval dataset should have 2000 samples after deduplication.",
) )
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_without_deduplication(self): def test_prepare_dataset_without_deduplication(self):
"""Verify that prepare_dataset function processes the dataset correctly without deduplication.""" """Verify that prepare_dataset function processes the dataset correctly without deduplication."""
self.cfg_1.dataset_exact_deduplication = False self.cfg_1.dataset_exact_deduplication = False

View File

@@ -12,6 +12,8 @@ from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="tokenizer") @pytest.fixture(name="tokenizer")
def fixture_tokenizer(): def fixture_tokenizer():
@@ -25,6 +27,7 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
""" """
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_size, num_workers", "batch_size, num_workers",
[ [
@@ -35,11 +38,12 @@ class TestBatchedSamplerPacking:
], ],
) )
@pytest.mark.parametrize("max_seq_length", [4096, 512]) @pytest.mark.parametrize("max_seq_length", [4096, 512])
@enable_hf_offline
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset( dataset = load_dataset(
"Trelis/tiny-shakespeare", "winglian/tiny-shakespeare",
split="train", split="train",
) )

View File

@@ -10,12 +10,15 @@ from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter from axolotl.prompters import AlpacaPrompter
from tests.hf_offline_utils import enable_hf_offline
class TestPacking(unittest.TestCase): class TestPacking(unittest.TestCase):
""" """
Test class for packing dataset sequences Test class for packing dataset sequences
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

View File

@@ -1,43 +1,60 @@
"""Module for testing streaming dataset sequence packing""" """Module for testing streaming dataset sequence packing"""
import functools import functools
import unittest import random
import string
import pytest import pytest
import torch import torch
from datasets import load_dataset from datasets import IterableDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
class TestPretrainingPacking(unittest.TestCase): class TestPretrainingPacking:
""" """
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
""" """
def setUp(self) -> None: @pytest.fixture
# pylint: disable=duplicate-code def random_text(self):
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") # seed with random.seed(0) for reproducibility
self.tokenizer.pad_token = "</s>" random.seed(0)
@pytest.mark.flaky(retries=3, delay=5) # generate row of random text with "words" of between 2 and 10 characters and
def test_packing_stream_dataset(self): # between 400 to 1200 characters per line
# pylint: disable=duplicate-code def rand_txt():
dataset = load_dataset( return " ".join(
"allenai/c4", [
"en", "".join(
streaming=True, random.choices(string.ascii_lowercase, k=random.randint(2, 10))
)["train"] )
for _ in range(random.randint(50, 200))
]
)
# Create a list of 2000 random texts rather than just using it within the
# generator so the test runs faster
data = [rand_txt() for _ in range(500)]
# Create an IterableDataset
def generator():
for row in data:
yield {"text": row}
return IterableDataset.from_generator(generator)
@pytest.mark.flaky(retries=1, delay=5)
def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
dataset = random_text
cfg = DictDefault( cfg = DictDefault(
{ {
"pretraining_dataset": [ "pretraining_dataset": [
{ {
"path": "allenai/c4", "path": "winglian/tiny-shakespeare",
"name": "en",
"type": "pretrain", "type": "pretrain",
} }
], ],
@@ -54,15 +71,16 @@ class TestPretrainingPacking(unittest.TestCase):
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
self.tokenizer, tokenizer_huggyllama,
cfg, cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
# pylint: disable=duplicate-code
original_bsz = cfg.micro_batch_size original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
dataset, dataset,
self.tokenizer, tokenizer_huggyllama,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
@@ -78,7 +96,7 @@ class TestPretrainingPacking(unittest.TestCase):
) )
idx = 0 idx = 0
for data in trainer_loader: for data in trainer_loader:
if idx > 10: if idx > 3:
break break
assert data["input_ids"].shape == torch.Size( assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len] [1, original_bsz * cfg.sequence_len]
@@ -95,7 +113,3 @@ class TestPretrainingPacking(unittest.TestCase):
# [1, original_bsz * cfg.sequence_len] # [1, original_bsz * cfg.sequence_len]
# ) # )
idx += 1 idx += 1
if __name__ == "__main__":
unittest.main()

View File

@@ -5,6 +5,7 @@ import logging
import unittest import unittest
from pathlib import Path from pathlib import Path
import pytest
from datasets import load_dataset from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -22,6 +23,8 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
test_data = { test_data = {
@@ -63,6 +66,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies. Test class for prompt tokenization strategies.
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -119,6 +123,7 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
Test class for prompt tokenization strategies with sys prompt from the dataset Test class for prompt tokenization strategies with sys prompt from the dataset
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -160,6 +165,7 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
Test class for prompt tokenization strategies with sys prompt from the dataset Test class for prompt tokenization strategies with sys prompt from the dataset
""" """
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf") self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
@@ -238,6 +244,7 @@ If a question does not make any sense, or is not factually coherent, explain why
class OrpoTokenizationTest(unittest.TestCase): class OrpoTokenizationTest(unittest.TestCase):
"""test case for the ORPO tokenization""" """test case for the ORPO tokenization"""
@enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
@@ -262,6 +269,7 @@ class OrpoTokenizationTest(unittest.TestCase):
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train" "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0]) ).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self): def test_orpo_integration(self):
strat = load( strat = load(
self.tokenizer, self.tokenizer,

View File

@@ -9,12 +9,15 @@ import pytest
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers: class TestTokenizers:
""" """
test class for the load_tokenizer fn test class for the load_tokenizer fn
""" """
@enable_hf_offline
def test_default_use_fast(self): def test_default_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -24,6 +27,7 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__ assert "Fast" in tokenizer.__class__.__name__
@enable_hf_offline
def test_dont_use_fast(self): def test_dont_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -34,6 +38,7 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__ assert "Fast" not in tokenizer.__class__.__name__
@enable_hf_offline
def test_special_tokens_modules_to_save(self): def test_special_tokens_modules_to_save(self):
# setting special_tokens to new token # setting special_tokens to new token
cfg = DictDefault( cfg = DictDefault(
@@ -68,6 +73,7 @@ class TestTokenizers:
) )
load_tokenizer(cfg) load_tokenizer(cfg)
@enable_hf_offline
def test_add_additional_special_tokens(self): def test_add_additional_special_tokens(self):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -83,6 +89,7 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001 assert len(tokenizer) == 32001
@enable_hf_offline
def test_added_tokens_overrides(self, temp_dir): def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -104,11 +111,12 @@ class TestTokenizers:
128042 128042
] ]
@enable_hf_offline
def test_added_tokens_overrides_with_toolargeid(self, temp_dir): def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {
# use with tokenizer that has reserved_tokens in added_tokens # use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B", "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"}, "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir, "output_dir": temp_dir,
} }

0
tests/utils/__init__.py Normal file
View File