Compare commits

..

3 Commits

Author SHA1 Message Date
Sunny Liu
c5c01c11d8 fix dumb mistakes 2025-03-27 13:33:52 -04:00
Sunny Liu
00ebf2faf9 message key checking 2025-03-27 13:29:17 -04:00
Sunny Liu
641e84188b add chat conversion for multiple choice format 2025-03-27 10:51:24 -04:00
44 changed files with 461 additions and 875 deletions

View File

@@ -136,4 +136,4 @@ jobs:
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
modal run cicd.tests

View File

@@ -63,7 +63,7 @@ jobs:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
@@ -137,7 +137,7 @@ jobs:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
@@ -171,9 +171,6 @@ jobs:
run: |
axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -232,7 +229,7 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
@@ -279,4 +276,4 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
modal run cicd.tests

View File

@@ -1,4 +1,3 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,7 +10,7 @@ load_in_4bit: true
strict: false
# huggingface repo
chat_template: gemma3
chat_template: gemma3_text
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template

View File

@@ -19,6 +19,7 @@ val_set_size: 0.0
output_dir: ./outputs/lora-out
dataset_exact_deduplication: true
test_value: true
sequence_len: 4096
sample_packing: true

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.5
liger-kernel==0.5.3
# END section
packaging==23.2
@@ -15,7 +15,7 @@ peft==0.15.0
transformers==4.50.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
datasets==3.4.1
deepspeed==0.16.4
trl==0.15.1

View File

@@ -25,8 +25,8 @@ import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")

View File

@@ -15,6 +15,7 @@ import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache, HybridCache
@@ -32,8 +33,6 @@ from transformers.utils import (
)
from transformers.utils.deprecation import deprecate_kwarg
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
_PATCH_OPTS: PatchOptions | None = None
@@ -135,17 +134,25 @@ def cce_forward(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
)
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**loss_kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if self.config.final_logit_softcapping is not None:
@@ -346,7 +353,6 @@ def cce_forward_multimodal(
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**lm_kwargs,
)
else:

View File

@@ -1,40 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Monkeypatch for apply_lce to add softcap."""
import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions
def apply_lce(
e: torch.Tensor,
c: torch.Tensor,
labels: torch.Tensor,
opts: PatchOptions,
bias: torch.Tensor | None = None,
softcap: float | None = None,
**loss_kwargs,
) -> torch.Tensor:
"""Monkey patch for apply_lce to support softcap kwarg."""
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
cce_kwargs = opts.to_kwargs()
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
cce_kwargs["reduction"] = "sum"
else:
num_items_in_batch = None
loss = linear_cross_entropy(
e,
c,
labels.to(e.device),
bias=bias,
shift=True,
softcap=softcap,
**cce_kwargs,
)
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return loss

View File

@@ -20,26 +20,6 @@ liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Supported Models
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
- mistral
- mixtral
- mllama
- mllama_text_model
- olmo2
- paligemma
- phi3
- qwen2
- qwen2_5_vl
- qwen2_vl
## Citation
```bib

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -42,18 +41,11 @@ class LigerPlugin(BasePlugin):
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
@@ -90,8 +82,6 @@ class LigerPlugin(BasePlugin):
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
@@ -114,51 +104,15 @@ class LigerPlugin(BasePlugin):
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type in ["deepseek_v3"]:
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -1,113 +1,153 @@
"""
Hijack the LlamaAttention forward method to use xformers if available.
Updated for transformers v4.50.0.
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
from typing import Optional
import logging
import warnings
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
def xformers_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs, # pylint: disable=unused-argument
):
"""
Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
Args:
module: The LlamaAttention module
query: Query states of shape [batch, num_heads, seq_len, head_dim]
key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
attention_mask: Attention mask
scaling: Scaling factor for attention scores
dropout: Dropout probability
Returns:
attn_output: Output of xformers memory-efficient attention
attn_weights: None
"""
# First, handle grouped-query attention (GQA)
# We need to repeat key and value states to match the number of query heads
num_key_value_groups = getattr(module, "num_key_value_groups", 1)
key = repeat_kv(key, num_key_value_groups)
value = repeat_kv(value, num_key_value_groups)
# xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Determine if we need a causal mask
is_causal = getattr(module, "is_causal", True)
# Set up the attention bias for xformers
if is_causal:
# Use xformers built-in causal mask
attn_bias = xformers.ops.LowerTriangularMask()
elif attention_mask is not None:
# For non-causal attention with a mask, we'd need to convert the mask
# This is a simplification - you might need to adapt based on your mask format
attn_bias = attention_mask
else:
# No mask needed
attn_bias = None
# Apply xformers memory-efficient attention
attn_output = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attn_bias,
p=dropout if module.training else 0.0,
scale=scaling,
)
# Reshape back to [batch, seq_len, hidden_size]
attn_output = attn_output.transpose(1, 2)
return attn_output, None # Return None for attn_weights to match interface
logging.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention():
"""
Patch the LlamaAttention forward method to use xformers if available.
"""
if not XFORMERS_AVAILABLE:
raise ValueError(
"xformers not available. Please install it following axolotl's requirements."
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
import transformers.models.llama.modeling_llama as llama_modeling
#
# xformers-attn start
#
# Add xformers to the available attention implementations
llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Create a wrapper for the original LlamaAttention forward method
original_forward = llama_modeling.LlamaAttention.forward
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
def patched_forward(self, *args, **kwargs):
# Set the attention implementation to xformers
# pylint: disable=protected-access
self.config._attn_implementation = "xformers"
return original_forward(self, *args, **kwargs)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# Apply the patch
llama_modeling.LlamaAttention.forward = patched_forward
#
# xformers-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -1,5 +1,6 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
import ast
from copy import deepcopy
from typing import Optional
@@ -75,6 +76,49 @@ class ProcessingStrategy:
result["messages"] = messages
return result
def convert_multiple_choice_to_multimedia_messages(
messages: dict,
) -> list[dict]:
def construct_prompt(sample):
question = sample["question"]
options = sample["options"]
if isinstance(options, str):
options = ast.literal_eval(options)
example = ""
start_chr = "A"
prediction_range = []
index2ans = {}
for option in options:
prediction_range.append(start_chr)
example += f"({start_chr}) {option}\n"
index2ans[start_chr] = option
start_chr = chr(ord(start_chr) + 1)
empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
empty_prompt = empty_prompt_sample_structure.format(question, example)
return empty_prompt
new_messages = []
user_content = construct_prompt(messages)
assistant_response = messages["answer"]
new_messages.append(
{"role": "user", "content": [{"type": "text", "text": user_content}]}
)
new_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": assistant_response}],
}
)
return new_messages
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
"""Convert regular messages format to Messages format with content type"""
@@ -106,39 +150,51 @@ class ProcessingStrategy:
processed_examples = []
for example in examples:
if not ("messages" in example or "conversations" in example):
if not (
"messages" in example
or "conversations" in example
or "question" in example
):
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
"Only `messages`, `conversations`, and `question` message keys are currently supported."
)
processed_example = None
if "messages" in example: # OpenAI format
processed_example = example
# convert regular messages format to Messages format with content type
# for compatibility with apply_chat_template
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
elif "question" in example: # Multiple choice format
processed_example = {}
processed_example["messages"] = (
convert_multiple_choice_to_multimedia_messages(example)
)
else: # Legacy format
processed_example = convert_legacy_format(example)
# convert regular messages format to Messages format with content type
# for compatibility with apply_chat_template
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
# find the image key if it exists
possible_image_keys = ["images", "image"]
image_key = None
for key in possible_image_keys:
if key in processed_example:
image_key = key
break
# if the image key exists, add the image to the first message
if image_key is not None:
# TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
image_value = processed_example[image_key][0]
image_keys = []
for key in example.keys():
if "image" in key:
image_keys.append(key)
for im_key in image_keys:
if example[im_key] is None:
continue
if isinstance(example[im_key], list):
if len(example[im_key]) == 0:
continue
image_value = example[im_key][0]
else:
image_value = example[im_key]
# Handle image loading (Image, url, path, base64)
image_value = load_image(image_value)
if self.image_size is not None:
@@ -163,33 +219,12 @@ class ProcessingStrategy:
color=padding_color,
)
# Look for any image type in the first message
# some dataset have an {type: "image"} in the first message
ind_to_add = None
for i, content in enumerate(
processed_example["messages"][0]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
ind_to_add = i
break
# If an image type is found, add the image to that index
if ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][
"image"
] = image_value
else:
# if no image type is found, add it to end of the first message
processed_example["messages"][0]["content"].append(
{
"type": "image",
"image": image_value,
}
)
processed_example["messages"][0]["content"].append(
{
"type": "image",
"image": image_value,
}
)
processed_examples.append(processed_example)

View File

@@ -411,15 +411,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if turn_idx >= len(turns):
raise ValueError(f"Turn index {turn_idx} out of range")
# mistral/gemma3 does not output message if it contains only system message
# mistral does not output message if it contains only system message
if (
turn_idx == 0
and turns[0].get("role") == "system"
and (
"mistral" in self.tokenizer.name_or_path.lower()
# gemma3 uses gemma tokenizer
or "gemma" in self.tokenizer.name_or_path.lower()
)
and "mistral" in self.tokenizer.name_or_path.lower()
):
return -1, -1

View File

@@ -14,7 +14,6 @@ import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -303,7 +302,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# Defensively push to the hub to ensure the model card is updated

View File

@@ -6,12 +6,8 @@ from pathlib import Path
from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import (
HFValidationError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HFValidationError
from axolotl.utils.dict import DictDefault
@@ -74,25 +70,20 @@ def load_dataset_w_config(
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
snapshot_download(
repo_id=config_dataset.path,
repo_type="dataset",
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
ignore_patterns=["*"],
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (
RepositoryNotFoundError,
RevisionNotFoundError,
FileNotFoundError,
ConnectionError,
HFValidationError,
ValueError,
):
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False

View File

@@ -8,7 +8,7 @@ import math
import os
import types
from functools import cached_property
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
import bitsandbytes as bnb
@@ -25,7 +25,7 @@ from peft import (
prepare_model_for_kbit_training,
)
from torch import nn
from transformers import (
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
@@ -39,7 +39,6 @@ from transformers import (
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
@@ -108,21 +107,14 @@ def get_module_class_from_name(module, name):
return None
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
# Set use_cache to False
if hasattr(model_config, "use_cache"):
model_config.use_cache = False
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
if cfg.is_multimodal:
# For multimodal configs, use_cache is set in the text_config
if hasattr(model_config, "get_text_config"):
text_config = model_config.get_text_config()
if hasattr(text_config, "use_cache"):
text_config.use_cache = False
else:
raise ValueError(
"No text config found for multimodal model. Please raise an Issue with model details."
)
if hasattr(model_config, "text_config"):
model_config = model_config.text_config
model_config.use_cache = False
elif hasattr(model_config, "get_text_config"):
model_config = model_config.get_text_config()
model_config.use_cache = False
# check if image_size is not set and load image size from model config if available
if (
@@ -531,6 +523,14 @@ class ModelLoader:
# init model config
self.model_config = load_model_config(cfg)
if cfg.is_multimodal:
if hasattr(self.model_config, "text_config"):
self.text_model_config = self.model_config.text_config
else:
# for qwen2_vl
self.text_model_config = self.model_config.get_text_config()
else:
self.text_model_config = self.model_config
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
@@ -947,6 +947,8 @@ class ModelLoader:
quantization_config = (
quantization_config or self.model_kwargs["quantization_config"]
)
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = load_sharded_model_quant(
self.base_model,
self.model_config,
@@ -967,6 +969,9 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
@@ -1021,6 +1026,8 @@ class ModelLoader:
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
@@ -1036,7 +1043,25 @@ class ModelLoader:
**self.model_kwargs,
)
else:
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(self.text_model_config, "max_seq_len")
and self.text_model_config.max_seq_len
and self.cfg.sequence_len > self.text_model_config.max_seq_len
):
self.text_model_config.max_seq_len = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
elif (
hasattr(self.text_model_config, "max_sequence_length")
and self.text_model_config.max_sequence_length
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
):
self.text_model_config.max_sequence_length = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1055,6 +1080,8 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1319,6 +1346,8 @@ class ModelLoader:
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates")
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if self.cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer

View File

View File

@@ -11,11 +11,7 @@ import time
import pytest
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -29,11 +25,9 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
) as exc:
if attempt < max_retries - 1:
wait = 2**attempt * delay # in seconds
time.sleep(wait)
time.sleep(delay)
else:
raise exc
@@ -43,7 +37,6 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
@retry_on_request_exceptions(max_retries=3, delay=5)
@disable_hf_offline
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@@ -51,19 +44,19 @@ def snapshot_download_w_retry(*args, **kwargs):
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
# download the model
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
snapshot_download_w_retry("JackFram/llama-68m")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
@pytest.fixture(scope="session", autouse=True)
@@ -108,37 +101,6 @@ def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
)
@pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
)
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset",
revision="ea82cff",
)
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
@@ -147,141 +109,10 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset():
# download the dataset
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_huggyllama_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"huggyllama/llama-7b",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-3.2-1B",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B-Instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_3_medium_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"casperhansen/mistral-7b-instruct-v0.1-awq",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma_2b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"unsloth/gemma-2b-it",
revision="703fb4a",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma2_9b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/gemma-2-9b-it-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mlx_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-2-7b-hf",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
@pytest.fixture
@@ -347,34 +178,3 @@ def cleanup_monkeypatches():
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass

View File

@@ -10,13 +10,10 @@ from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
from tests.hf_offline_utils import enable_hf_offline # noqa
@pytest.fixture(scope="session", name="llama_tokenizer")
@enable_hf_offline
def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer")

View File

@@ -5,6 +5,7 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -12,8 +13,6 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):

View File

@@ -2,13 +2,15 @@
Simple end-to-end test for Liger integration
"""
from e2e.utils import require_torch_2_4_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
from ..utils import check_model_output_exists
class LigerIntegrationTestCase:

View File

@@ -8,12 +8,11 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO:
"""

View File

@@ -9,13 +9,12 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -9,11 +9,10 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -14,8 +14,6 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -25,7 +23,6 @@ class TestDeepseekV3:
Test case for DeepseekV3 models
"""
@enable_hf_offline
@pytest.mark.parametrize(
"sample_packing",
[True, False],
@@ -83,7 +80,6 @@ class TestDeepseekV3:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@enable_hf_offline
@pytest.mark.parametrize(
"sample_packing",
[True, False],

View File

@@ -5,14 +5,14 @@ E2E tests for llama
import logging
import os
from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,85 +0,0 @@
"""
test utils for helpers and decorators
"""
import os
from functools import wraps
from huggingface_hub.utils import reset_sessions
def reload_modules(hf_hub_offline):
# Force reload of the modules that check this variable
import importlib
import datasets
import huggingface_hub.constants
# Reload the constants module first, as others depend on it
importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions()
def enable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "1"
reload_modules(True)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper
def disable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "0"
reload_modules(False)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper

View File

@@ -4,13 +4,12 @@ shared fixtures for prompt strategies tests
import pytest
from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
@@ -109,27 +108,31 @@ def fixture_toolcalling_dataset():
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_llama3_tokenizer(
download_llama3_8b_instruct_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
def fixture_llama3_tokenizer():
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
filename="special_tokens_map.json",
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
filename="tokenizer_config.json",
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_smollm2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_mistralv03_tokenizer(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
def fixture_mistralv03_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
)
@@ -137,7 +140,6 @@ def fixture_mistralv03_tokenizer(
@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_phi35_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
return tokenizer

View File

@@ -11,8 +11,6 @@ from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="alpaca_dataset")
def fixture_alpaca_dataset():
@@ -28,7 +26,6 @@ def fixture_alpaca_dataset():
@pytest.fixture(name="tokenizer")
@enable_hf_offline
def fixture_tokenizer():
# pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -13,11 +13,8 @@ from axolotl.utils.chat_templates import (
get_chat_template,
)
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="llama3_tokenizer")
@enable_hf_offline
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

View File

@@ -17,8 +17,6 @@ from axolotl.prompt_strategies.chat_template import (
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@@ -32,14 +30,12 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja",
"[/INST]",
),
# TODO: temporarily skip gemma due to gemma3 template
# Re-enable on new chat_template implementation for perf
# (
# "gemma2_tokenizer",
# "jinja",
# "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
(
"gemma2_tokenizer",
"jinja",
"gemma2_tokenizer_chat_template_jinja",
"<end_of_turn>",
),
("phi35_tokenizer", "phi_35", None, "<|end|>"),
]
@@ -97,11 +93,7 @@ class TestChatTemplateConfigurations:
if (
turn_idx == 0
and turn.get("from") in ["system", "context"]
and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
and "mistral" in tokenizer.name_or_path.lower()
):
assert (
start_idx == -1 and end_idx == -1
@@ -109,7 +101,6 @@ class TestChatTemplateConfigurations:
return True
return False
@enable_hf_offline
def test_train_on_inputs_true(
self,
tokenizer,

View File

@@ -11,8 +11,6 @@ from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
@@ -80,8 +78,15 @@ def fixture_custom_assistant_dataset():
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
@pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
@@ -89,7 +94,6 @@ def fixture_phi3_tokenizer():
@pytest.fixture(name="gemma_tokenizer")
@enable_hf_offline
def fixture_gemma_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")

View File

@@ -10,8 +10,6 @@ from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():
@@ -36,8 +34,6 @@ class TestDPOChatml:
Test loading DPO preference datasets with chatml formatting
"""
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_default(self, minimal_dpo_cfg):
cfg = DictDefault(
{

View File

@@ -8,15 +8,12 @@ from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
from tests.hf_offline_utils import enable_hf_offline
class TestEncodePretraining(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""
@enable_hf_offline
def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(

View File

@@ -4,37 +4,31 @@ Test dataset loading under various conditions.
import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
from conftest import snapshot_download_w_retry
from constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from tests.hf_offline_utils import enable_hf_offline
class TestDatasetPreparation:
class TestDatasetPreparation(unittest.TestCase):
"""Test a configured dataloader."""
@pytest.fixture
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
yield tokenizer_huggyllama
@pytest.fixture
def dataset_fixture(self):
yield Dataset.from_list(
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
# Alpaca dataset.
self.dataset = Dataset.from_list(
[
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -44,9 +38,7 @@ class TestDatasetPreparation:
]
)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
def test_load_hub(self):
"""Core use case. Verify that processing data from the hub works"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
@@ -63,28 +55,25 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub(self, tokenizer):
def test_load_local_hub(self):
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_path = snapshot_download(
snapshot_download_w_retry(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
)
# offline mode doesn't actually copy it to local_dir, so we
# have to copy all the contents in the dir manually from the returned snapshot_path
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared"
# Right now a local copy that doesn't fully conform to a dataset
@@ -107,7 +96,9 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -115,12 +106,11 @@ class TestDatasetPreparation:
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
def test_load_from_save_to_disk(self):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
dataset_fixture.save_to_disk(str(tmp_ds_name))
self.dataset.save_to_disk(str(tmp_ds_name))
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -136,21 +126,22 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
def test_load_from_dir_of_parquet(self):
"""Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
dataset_fixture.to_parquet(tmp_ds_path)
self.dataset.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -171,21 +162,22 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
def test_load_from_dir_of_json(self):
"""Standard use case. Verify a directory of json files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.json"
dataset_fixture.to_json(tmp_ds_path)
self.dataset.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -206,19 +198,20 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
def test_load_from_single_parquet(self):
"""Standard use case. Verify a single parquet file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
dataset_fixture.to_parquet(tmp_ds_path)
self.dataset.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -235,19 +228,20 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_single_json(self, tokenizer, dataset_fixture):
def test_load_from_single_json(self):
"""Standard use case. Verify a single json file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
dataset_fixture.to_json(tmp_ds_path)
self.dataset.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -264,15 +258,15 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@enable_hf_offline
def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""
@@ -291,9 +285,7 @@ class TestDatasetPreparation:
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub_with_revision(self, tokenizer):
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
@@ -315,17 +307,16 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
@enable_hf_offline
def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
):
def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault(
@@ -338,34 +329,22 @@ class TestDatasetPreparation:
}
)
# pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self, tokenizer):
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_path = snapshot_download(
snapshot_download_w_retry(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -386,7 +365,9 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -394,19 +375,17 @@ class TestDatasetPreparation:
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer):
def test_loading_local_dataset_folder(self):
"""Verify that a dataset downloaded to a local folder can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_path = snapshot_download(
snapshot_download_w_retry(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
)
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -422,10 +401,16 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -8,8 +8,9 @@ import hashlib
import unittest
from unittest.mock import patch
import pytest
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset
@@ -18,9 +19,6 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
from tests.hf_offline_utils import enable_hf_offline
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
"""
@@ -216,12 +214,13 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset:
class TestDeduplicateRLDataset(unittest.TestCase):
"""Test a configured dataloader with deduplication."""
@pytest.fixture
def cfg(self):
fixture = DictDefault(
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
@@ -234,66 +233,34 @@ class TestDeduplicateRLDataset:
],
}
)
yield fixture
@enable_hf_offline
def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
def test_load_with_deduplication(self):
"""Verify that loading with deduplication removes duplicates."""
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
def test_load_without_deduplication(self):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
@enable_hf_offline
def test_load_without_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase):
"""Test prepare_dataset function with different configurations."""
@enable_hf_offline
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault(
{
"base_model": "huggyllama/llama-7b",
@@ -319,8 +286,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
normalize_config(self.cfg_1)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True
@@ -346,8 +311,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
"Train dataset should have 2000 samples after deduplication.",
)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_with_deduplication_eval(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True
@@ -373,8 +336,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
"Eval dataset should have 2000 samples after deduplication.",
)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_prepare_dataset_without_deduplication(self):
"""Verify that prepare_dataset function processes the dataset correctly without deduplication."""
self.cfg_1.dataset_exact_deduplication = False

View File

@@ -12,8 +12,6 @@ from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
@@ -27,7 +25,6 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences
"""
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize(
"batch_size, num_workers",
[
@@ -38,12 +35,11 @@ class TestBatchedSamplerPacking:
],
)
@pytest.mark.parametrize("max_seq_length", [4096, 512])
@enable_hf_offline
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset(
"winglian/tiny-shakespeare",
"Trelis/tiny-shakespeare",
split="train",
)

View File

@@ -10,15 +10,12 @@ from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from tests.hf_offline_utils import enable_hf_offline
class TestPacking(unittest.TestCase):
"""
Test class for packing dataset sequences
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

View File

@@ -1,60 +1,43 @@
"""Module for testing streaming dataset sequence packing"""
import functools
import random
import string
import unittest
import pytest
import torch
from datasets import IterableDataset
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault
class TestPretrainingPacking:
class TestPretrainingPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""
@pytest.fixture
def random_text(self):
# seed with random.seed(0) for reproducibility
random.seed(0)
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
# generate row of random text with "words" of between 2 and 10 characters and
# between 400 to 1200 characters per line
def rand_txt():
return " ".join(
[
"".join(
random.choices(string.ascii_lowercase, k=random.randint(2, 10))
)
for _ in range(random.randint(50, 200))
]
)
# Create a list of 2000 random texts rather than just using it within the
# generator so the test runs faster
data = [rand_txt() for _ in range(500)]
# Create an IterableDataset
def generator():
for row in data:
yield {"text": row}
return IterableDataset.from_generator(generator)
@pytest.mark.flaky(retries=1, delay=5)
def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
dataset = random_text
@pytest.mark.flaky(retries=3, delay=5)
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
"allenai/c4",
"en",
streaming=True,
)["train"]
cfg = DictDefault(
{
"pretraining_dataset": [
{
"path": "winglian/tiny-shakespeare",
"path": "allenai/c4",
"name": "en",
"type": "pretrain",
}
],
@@ -71,16 +54,15 @@ class TestPretrainingPacking:
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
tokenizer_huggyllama,
self.tokenizer,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
# pylint: disable=duplicate-code
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
dataset,
tokenizer_huggyllama,
self.tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
@@ -96,7 +78,7 @@ class TestPretrainingPacking:
)
idx = 0
for data in trainer_loader:
if idx > 3:
if idx > 10:
break
assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
@@ -113,3 +95,7 @@ class TestPretrainingPacking:
# [1, original_bsz * cfg.sequence_len]
# )
idx += 1
if __name__ == "__main__":
unittest.main()

View File

@@ -5,7 +5,6 @@ import logging
import unittest
from pathlib import Path
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -23,8 +22,6 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl")
test_data = {
@@ -66,7 +63,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -123,7 +119,6 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -165,7 +160,6 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
@@ -244,7 +238,6 @@ If a question does not make any sense, or is not factually coherent, explain why
class OrpoTokenizationTest(unittest.TestCase):
"""test case for the ORPO tokenization"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained(
@@ -269,7 +262,6 @@ class OrpoTokenizationTest(unittest.TestCase):
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self):
strat = load(
self.tokenizer,

View File

@@ -9,15 +9,12 @@ import pytest
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers:
"""
test class for the load_tokenizer fn
"""
@enable_hf_offline
def test_default_use_fast(self):
cfg = DictDefault(
{
@@ -27,7 +24,6 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__
@enable_hf_offline
def test_dont_use_fast(self):
cfg = DictDefault(
{
@@ -38,7 +34,6 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__
@enable_hf_offline
def test_special_tokens_modules_to_save(self):
# setting special_tokens to new token
cfg = DictDefault(
@@ -73,7 +68,6 @@ class TestTokenizers:
)
load_tokenizer(cfg)
@enable_hf_offline
def test_add_additional_special_tokens(self):
cfg = DictDefault(
{
@@ -89,7 +83,6 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001
@enable_hf_offline
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
@@ -111,12 +104,11 @@ class TestTokenizers:
128042
]
@enable_hf_offline
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}