diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 15902516d..a68377f07 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -25,7 +25,11 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.config import ( + normalize_cfg_datasets, + normalize_config, + validate_config, +) from axolotl.utils.data import prepare_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process @@ -289,6 +293,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs): normalize_config(cfg) + normalize_cfg_datasets(cfg) + setup_wandb_env_vars(cfg) setup_mlflow_env_vars(cfg) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 9a69184d8..f88490383 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -150,6 +150,21 @@ def normalize_config(cfg): log_gpu_memory_usage(LOG, "baseline", cfg.device) +def normalize_cfg_datasets(cfg): + """ + helpers for mapping chat_template to various dataset configurations as necessary + """ + + if cfg.chat_template and cfg.chat_template == "chatml": + if cfg.datasets: + for idx, ds_cfg in enumerate(cfg.datasets): + if ds_cfg.type == "sharegpt" and not ds_cfg.conversation: + LOG.info( + f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template" + ) + cfg.datasets[idx].conversation = "chatml" + + def validate_config(cfg): """ This is a "pre-validation" step that handles the yaml configuration before we have any diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 1397b23af..004d0068e 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -3,7 +3,7 @@ Test classes for checking functionality of the cfg normalization """ import unittest -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_cfg_datasets, normalize_config from axolotl.utils.dict import DictDefault @@ -44,3 +44,26 @@ class NormalizeConfigTestCase(unittest.TestCase): normalize_config(cfg) assert cfg.base_model_config == cfg.base_model + + def test_chat_template_chatml(self): + cfg = DictDefault( + { + "chat_template": "chatml", + "datasets": [ + { + "path": "lorem/ipsum", + "type": "sharegpt", + "conversation": "vicuna_v1.1", + }, + { + "path": "sit/amet", + "type": "sharegpt", + }, + ], + } + ) + + normalize_cfg_datasets(cfg) + + assert cfg.datasets[0].conversation == "vicuna_v1.1" + assert cfg.datasets[1].conversation == "chatml"