fix: minor patches for multimodal (#2441)

* fix: update chat_template

* fix: handle gemma3 showing a lot of no content for turn 0

* fix: remove unknown config from examples

* fix: test

* fix: temporary disable gemma2 test

* fix: stop overwriting config.text_config unnecessarily

* fix: handling of set cache to the text_config section

* feat: add liger gemma support and bump liger to 0.5.5

* fix: add double use_cache setting

* fix: add support for final_logit_softcap in CCE for gemma2/3

* fix: set use_cache before model load

* feat: add missing layernorm override

* fix: handle gemma3 rmsnorm

* fix: use wrapper to pass dim as hidden_size

* fix: change dim to positional

* fix: patch with wrong mlp

* chore: refactor use_cache handling

* fix import issues

* fix tests.e2e.utils import

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2025-03-31 13:40:12 +07:00
committed by GitHub
parent 4ba80a0e5a
commit cf0c79d52e
38 changed files with 287 additions and 188 deletions

0
tests/__init__.py Normal file
View File

View File

@@ -14,7 +14,8 @@ import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from utils import disable_hf_offline, enable_hf_offline
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1):

View File

@@ -6,11 +6,12 @@ import unittest
import pytest
from transformers import AddedToken, AutoTokenizer
from utils import enable_hf_offline
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
from tests.hf_offline_utils import enable_hf_offline # noqa
@pytest.fixture(scope="session", name="llama_tokenizer")
@enable_hf_offline

View File

@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -13,6 +12,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):

View File

@@ -2,15 +2,13 @@
Simple end-to-end test for Liger integration
"""
from e2e.utils import require_torch_2_4_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
class LigerIntegrationTestCase:

View File

@@ -8,11 +8,12 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO:
"""

View File

@@ -9,12 +9,13 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -9,10 +9,11 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -7,7 +7,6 @@ import os
from pathlib import Path
import pytest
from utils import enable_hf_offline
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -15,6 +14,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -5,14 +5,14 @@ E2E tests for llama
import logging
import os
from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

85
tests/hf_offline_utils.py Normal file
View File

@@ -0,0 +1,85 @@
"""
test utils for helpers and decorators
"""
import os
from functools import wraps
from huggingface_hub.utils import reset_sessions
def reload_modules(hf_hub_offline):
# Force reload of the modules that check this variable
import importlib
import datasets
import huggingface_hub.constants
# Reload the constants module first, as others depend on it
importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions()
def enable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "1"
reload_modules(True)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper
def disable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "0"
reload_modules(False)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper

View File

@@ -5,11 +5,12 @@ shared fixtures for prompt strategies tests
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():

View File

@@ -6,12 +6,13 @@ import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="alpaca_dataset")
def fixture_alpaca_dataset():

View File

@@ -6,7 +6,6 @@ import unittest
import pytest
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.utils.chat_templates import (
_CHAT_TEMPLATES,
@@ -14,6 +13,8 @@ from axolotl.utils.chat_templates import (
get_chat_template,
)
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="llama3_tokenizer")
@enable_hf_offline

View File

@@ -9,7 +9,6 @@ import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import PreTrainedTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
@@ -18,6 +17,8 @@ from axolotl.prompt_strategies.chat_template import (
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@@ -31,12 +32,14 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja",
"[/INST]",
),
(
"gemma2_tokenizer",
"jinja",
"gemma2_tokenizer_chat_template_jinja",
"<end_of_turn>",
),
# TODO: temporarily skip gemma due to gemma3 template
# Re-enable on new chat_template implementation for perf
# (
# "gemma2_tokenizer",
# "jinja",
# "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
("phi35_tokenizer", "phi_35", None, "<|end|>"),
]
@@ -94,7 +97,11 @@ class TestChatTemplateConfigurations:
if (
turn_idx == 0
and turn.get("from") in ["system", "context"]
and "mistral" in tokenizer.name_or_path.lower()
and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
):
assert (
start_idx == -1 and end_idx == -1

View File

@@ -7,11 +7,12 @@ import unittest
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():

View File

@@ -5,12 +5,13 @@ Tests for loading DPO preference datasets with chatml formatting
import unittest
import pytest
from utils import enable_hf_offline
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():

View File

@@ -5,10 +5,11 @@ test module for the axolotl.utils.data module
import unittest
from transformers import LlamaTokenizer
from utils import enable_hf_offline
from axolotl.utils.data import encode_pretraining, md5
from tests.hf_offline_utils import enable_hf_offline
class TestEncodePretraining(unittest.TestCase):
"""

View File

@@ -8,20 +8,21 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from utils import enable_hf_offline
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from tests.hf_offline_utils import enable_hf_offline
class TestDatasetPreparation:
"""Test a configured dataloader."""

View File

@@ -9,9 +9,7 @@ import unittest
from unittest.mock import patch
import pytest
from constants import ALPACA_MESSAGES_CONFIG_REVISION
from datasets import Dataset
from utils import enable_hf_offline
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset
@@ -20,6 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
from tests.hf_offline_utils import enable_hf_offline
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
"""

View File

@@ -4,7 +4,6 @@ import pytest
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
@@ -13,6 +12,8 @@ from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():

View File

@@ -5,12 +5,13 @@ from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from tests.hf_offline_utils import enable_hf_offline
class TestPacking(unittest.TestCase):
"""

View File

@@ -8,7 +8,6 @@ from pathlib import Path
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
@@ -24,6 +23,8 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl")
test_data = {

View File

@@ -5,11 +5,12 @@ Test cases for the tokenizer loading
import unittest
import pytest
from utils import enable_hf_offline
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers:
"""

View File

@@ -1,85 +0,0 @@
"""
test utils for helpers and decorators
"""
import os
from functools import wraps
from huggingface_hub.utils import reset_sessions
def reload_modules(hf_hub_offline):
# Force reload of the modules that check this variable
import importlib
import datasets
import huggingface_hub.constants
# Reload the constants module first, as others depend on it
importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions()
def enable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "1"
reload_modules(True)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper
def disable_hf_offline(test_func):
"""
test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
:param test_func:
:return:
"""
@wraps(test_func)
def wrapper(*args, **kwargs):
# Save the original value of HF_HUB_OFFLINE environment variable
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
# Set HF_OFFLINE environment variable to True
os.environ["HF_HUB_OFFLINE"] = "0"
reload_modules(False)
try:
# Run the test function
return test_func(*args, **kwargs)
finally:
# Restore the original value of HF_HUB_OFFLINE environment variable
if original_hf_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_offline
reload_modules(bool(original_hf_offline))
else:
del os.environ["HF_HUB_OFFLINE"]
reload_modules(False)
return wrapper