fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (#1298)

* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path

* fix: normalize config
This commit is contained in:
NanoCode012
2024-03-25 15:34:54 +09:00
committed by GitHub
parent 324d59ea0d
commit ff939d8a64
4 changed files with 13 additions and 4 deletions

View File

@@ -119,6 +119,10 @@ def normalize_config(cfg):
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type cfg.model_config_type = model_config.model_type
cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)
# figure out if the model is llama # figure out if the model is llama
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama") (hasattr(model_config, "model_type") and model_config.model_type == "llama")

View File

@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
split="train", split="train",
) -> Tuple[DatasetDict, List[Prompter]]: ) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = cfg.tokenizer_config
ds_hash = str( ds_hash = str(
md5( md5(
( (

View File

@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type: if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type) tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained( tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config, cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast, use_fast=use_fast,
**tokenizer_kwargs, **tokenizer_kwargs,

View File

@@ -1,16 +1,18 @@
""" """
unit tests for axolotl.core.trainer_builder unit tests for axolotl.core.trainer_builder
""" """
import pytest import pytest
from axolotl.core.trainer_builder import HFDPOTrainerBuilder from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
@pytest.fixture(name="cfg") @pytest.fixture(name="cfg")
def fixture_cfg(): def fixture_cfg():
return DictDefault( cfg = DictDefault(
{ {
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"model_type": "AutoModelForCausalLM", "model_type": "AutoModelForCausalLM",
@@ -34,6 +36,10 @@ def fixture_cfg():
} }
) )
normalize_config(cfg)
return cfg
@pytest.fixture(name="tokenizer") @pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg): def fixture_tokenizer(cfg):