fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (#1298)
* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path
* fix: normalize config
This commit is contained in:
@@ -119,6 +119,10 @@ def normalize_config(cfg):
|
||||
model_config = load_model_config(cfg)
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
cfg.tokenizer_config = (
|
||||
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||
)
|
||||
|
||||
# figure out if the model is llama
|
||||
cfg.is_llama_derived_model = (
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
|
||||
@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
|
||||
split="train",
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
tokenizer_name = cfg.tokenizer_config
|
||||
ds_hash = str(
|
||||
md5(
|
||||
(
|
||||
|
||||
@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_config,
|
||||
cfg.tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""
|
||||
unit tests for axolotl.core.trainer_builder
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
@@ -34,6 +36,10 @@ def fixture_cfg():
|
||||
}
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer(cfg):
|
||||
|
||||
Reference in New Issue
Block a user