fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (#1298)
* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path * fix: normalize config
This commit is contained in:
@@ -119,6 +119,10 @@ def normalize_config(cfg):
|
|||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
|
cfg.tokenizer_config = (
|
||||||
|
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||||
|
)
|
||||||
|
|
||||||
# figure out if the model is llama
|
# figure out if the model is llama
|
||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
split="train",
|
split="train",
|
||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = cfg.tokenizer_config
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
tokenizer_config,
|
cfg.tokenizer_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
use_fast=use_fast,
|
use_fast=use_fast,
|
||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
unit tests for axolotl.core.trainer_builder
|
unit tests for axolotl.core.trainer_builder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="cfg")
|
@pytest.fixture(name="cfg")
|
||||||
def fixture_cfg():
|
def fixture_cfg():
|
||||||
return DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||||
"model_type": "AutoModelForCausalLM",
|
"model_type": "AutoModelForCausalLM",
|
||||||
@@ -34,6 +36,10 @@ def fixture_cfg():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="tokenizer")
|
@pytest.fixture(name="tokenizer")
|
||||||
def fixture_tokenizer(cfg):
|
def fixture_tokenizer(cfg):
|
||||||
|
|||||||
Reference in New Issue
Block a user