feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message * chore: cleanup * chore: fix nits and add light docs * docs: update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * feat: configurable message mappings, jinja template analyzer * chore: handle bradley terry * docs: update docs * refactor: change order of mappings, improve message transform * refactor: make chat awware of property mappings * chore: remove .python-version * chore: revert change * chore: add dataset validation to tests where appropriate * chore: add dataset validation to tests where appropriate * chore: clean up handling of ds_cfg * chore: recursively serialize config * make sure to use the return value from validate_config * DefaultDict pickle/unpickle fix * fix super call for override * refactor: message fields * chore: empty commit * tests: validate config before using * chore: add config validation to all e2e tests * chore: add unneeded logging * chore: add missed config validation * chore: pass field_messages to prompter * test: fix borked test * chore: remove uninteded file * chore: add deprecation warning and update chat_datasets script * chore: lint * refactor: message fields * feat: update axolotlinputconfig and test_models - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes - update model_dump method in axolotlinputconfig to exclude none values - correct typo in test_models.py comment * feat: simplify dpodataset and ktodataset classes in config models removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets. * feat: improve readability and structure in dataset configuration models this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file. * feat: change log level from info to debug in chattemplatestrategy * feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor - Add `messages_array_name` property to `ChatTemplatePrompter` - Change `processor` type to Optional in `ChatTemplatePrompter` - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt` - Remove `_messages` property from `ChatTemplateStrategy` - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor - Remove `messages` getter and setter from `ChatTemplateStrategy` - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread` - Remove condition to set `messages` field in `load` function * feat(tests/utils): ignore type check in load_model call in test_models.py * feat: improve type handling and test structure in chat templates - Add return type hint for `get_chat_template` function in `chat_templates.py` - Remove unnecessary assignment of `strategy.messages` in several test cases - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py` - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py` * feat(axolotl): enhance chat strategy with datasetconfig support This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include: - Importing Union from typing and BaseModel from pydantic. - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader. - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances. - Refactoring the prompter_params and strategy_params for better readability. - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method. * feat: update message handling in btchattemplatestrategy * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages" * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages" * Add a new attribute "messages_array_name" to the `load` function parameters * Remove the conditional attribute assignment for "field_messages" in the `load` function * feat: add config validation in test_kd.py - Import `validate_config` from `axolotl.utils.config` - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class * feat: enhance config validation and capabilities handling * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)` * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump * feat: update config validation in axolotl utils - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities` * feat: refactor strategyloader in chat_template.py - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)` - Created a new function, `_get_strategy_cls()`, to obtain the strategy class - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation * trigger CI * chore: revert dataset config changes for kto/dpo * subject: refactor: rename 'messages_array_name' to 'field_messages' Body: - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py' - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages' - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name' * feat: refactor prompt strategies and update config models * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py` * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py` * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model * chore: remove unused input * chore: remove redundant type ignore * fix: remove old configs and update examples * fix: type check * fix: remove loading old config in ChatMessage * fix: update faq with potential new undefinederror * fix: add debug if property mapped is not found * chore: improve explanation for unmapped properties * fix: update docs with new config * chore: add note for deprecation config and del old config from dict --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ class TestKnowledgeDistillation:
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -109,6 +110,7 @@ class TestKnowledgeDistillation:
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
@@ -76,7 +76,9 @@ class TestFAXentropyLlama:
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -73,6 +73,8 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_preference_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -108,6 +110,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -153,6 +157,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -198,6 +204,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -242,6 +250,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -289,6 +299,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -353,6 +365,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -56,6 +56,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -65,6 +65,8 @@ class TestFalcon(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -118,6 +120,8 @@ class TestFalcon(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -157,6 +161,8 @@ class TestFalcon(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
@@ -56,6 +56,8 @@ class TestLlama:
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -99,6 +101,8 @@ class TestLlama:
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -138,6 +142,8 @@ class TestLlama:
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
@@ -69,6 +69,8 @@ class TestPretrainLlama:
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -62,6 +62,8 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestMamba(unittest.TestCase):
|
||||
"save_safetensors": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,8 @@ class TestMistral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -106,6 +108,8 @@ class TestMistral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -69,6 +69,8 @@ class TestMixtral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -123,6 +125,8 @@ class TestMixtral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -180,6 +184,8 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -233,6 +239,8 @@ class TestMixtral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
@@ -281,6 +289,8 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -103,6 +105,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -139,6 +143,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestPackedLlama(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -61,6 +61,7 @@ class TestPhi(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -40,8 +40,10 @@ class TestE2eQwen:
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -66,6 +66,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -7,6 +7,7 @@ from datasets import Dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
||||
|
||||
|
||||
@@ -174,3 +175,32 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:
|
||||
modified_template = template.replace(old_date_logic, new_date_logic)
|
||||
|
||||
return modified_template
|
||||
|
||||
|
||||
@pytest.fixture(name="chat_template_jinja_with_optional_fields")
|
||||
def fixture_chat_template_jinja_with_optional_fields() -> str:
|
||||
return """{% for message in messages %}
|
||||
{{'<|im_start|>'}}{{ message['role'] }}
|
||||
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
|
||||
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
|
||||
{{ message['content'] }}{{'<|im_end|>'}}
|
||||
{% endfor %}"""
|
||||
|
||||
|
||||
@pytest.fixture(name="basic_jinja_template_analyzer")
|
||||
def basic_jinja_template_analyzer():
|
||||
return JinjaTemplateAnalyzer(
|
||||
"""{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
|
||||
' + message['content'] + '<|end|>
|
||||
'}}{% elif message['role'] == 'user' %}{{'<|user|>
|
||||
' + message['content'] + '<|end|>
|
||||
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
|
||||
' + message['content'] + '<|end|>
|
||||
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
|
||||
' }}{% else %}{{ eos_token }}{% endif %}"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="mistral_jinja_template_analyzer")
|
||||
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
|
||||
return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)
|
||||
|
||||
@@ -38,6 +38,10 @@ class TestAssistantChatTemplateLlama3:
|
||||
"chat_template": "llama3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
"roles": {
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -74,8 +78,10 @@ class TestAssistantChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -86,7 +92,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
@@ -114,8 +120,10 @@ class TestAssistantChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
phi35_tokenizer,
|
||||
chat_template=get_chat_template("phi_35"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -126,7 +134,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -170,9 +178,11 @@ class TestAssistantChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_field_training="training",
|
||||
message_property_mappings={
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -185,7 +195,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
|
||||
prompt_tokens = strategy.prompter.build_prompt(
|
||||
assistant_dataset[0]["messages"], False
|
||||
)
|
||||
@@ -230,8 +240,11 @@ class TestSharegptChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -239,7 +252,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["gpt"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -287,8 +300,11 @@ class TestSharegptChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -296,7 +312,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["human"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -344,8 +360,11 @@ class TestSharegptChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -353,7 +372,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["system", "human"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -417,8 +436,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
chat_template=get_chat_template(
|
||||
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
||||
),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -486,8 +504,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
chat_template=get_chat_template(
|
||||
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
||||
),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
|
||||
@@ -3,7 +3,6 @@ tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
@@ -123,15 +122,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=True,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -180,15 +179,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -241,20 +240,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant", "human"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
strategy.messages = "conversations"
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -307,15 +301,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=True,
|
||||
sequence_len=512,
|
||||
roles_to_train=["human", "assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -360,8 +354,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -369,7 +363,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=[],
|
||||
train_on_eos="none", # Add this line
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
|
||||
@@ -400,8 +394,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -409,7 +403,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="all",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -446,8 +440,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -455,7 +449,6 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="turn",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -526,8 +519,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -535,7 +528,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="last",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -578,8 +571,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -587,7 +580,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="none",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -624,15 +617,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
drop_system_message=True,
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
@@ -668,8 +661,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
roles=custom_roles,
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -741,8 +733,7 @@ class TestChatTemplateConfigurations:
|
||||
),
|
||||
message_field_training="train",
|
||||
message_field_training_detail="train_detail",
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -911,6 +902,64 @@ class TestChatTemplateConfigurations:
|
||||
LOG.debug(f"Final labels: {labels}")
|
||||
LOG.debug(f"Final input_ids: {input_ids}")
|
||||
|
||||
def test_get_chat_template_variables(
|
||||
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
):
|
||||
LOG.info("Testing get_chat_template_variables")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
actual_tokenizer, actual_jinja_template = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
)
|
||||
|
||||
prompter = ChatTemplatePrompter(
|
||||
actual_tokenizer,
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=actual_jinja_template
|
||||
),
|
||||
message_property_mappings={"from": "role", "value": "content"},
|
||||
)
|
||||
|
||||
variables = prompter.get_chat_template_msg_variables(
|
||||
actual_jinja_template
|
||||
if actual_jinja_template
|
||||
else actual_tokenizer.get_chat_template(),
|
||||
"messages",
|
||||
)
|
||||
|
||||
if chat_template == "llama3":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "chatml":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
|
||||
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "phi_35":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
159
tests/prompt_strategies/test_jinja_template_analyzer.py
Normal file
159
tests/prompt_strategies/test_jinja_template_analyzer.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
tests for jinja_template_analyzer
|
||||
"""
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class TestJinjaTemplateAnalyzer:
|
||||
"""
|
||||
tests for jinja_template_analyzer
|
||||
"""
|
||||
|
||||
def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
|
||||
"""Test that all top-level variables are correctly extracted."""
|
||||
LOG.info("Testing with train_on_inputs=True")
|
||||
|
||||
variables = basic_jinja_template_analyzer.get_template_variables()
|
||||
expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
|
||||
assert set(variables.keys()) == expected_vars
|
||||
|
||||
def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
|
||||
"""Test that all top-level variables are correctly extracted."""
|
||||
LOG.info("Testing with train_on_inputs=True")
|
||||
|
||||
variables = mistral_jinja_template_analyzer.get_template_variables()
|
||||
expected_vars = {
|
||||
"messages",
|
||||
"content",
|
||||
"eos_token",
|
||||
"message",
|
||||
"tools",
|
||||
"system_message",
|
||||
"loop_messages",
|
||||
"ns",
|
||||
"tool_call",
|
||||
"tool",
|
||||
"loop",
|
||||
"bos_token",
|
||||
"raise_exception",
|
||||
}
|
||||
assert set(variables.keys()) == expected_vars
|
||||
message_vars = variables["message"]
|
||||
assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}
|
||||
|
||||
def test_message_property_access(self, basic_jinja_template_analyzer):
|
||||
"""Test that properties accessed on 'message' variable are correctly identified."""
|
||||
LOG.info("Testing message property access")
|
||||
|
||||
variables = basic_jinja_template_analyzer.get_template_variables()
|
||||
assert "messages" in variables
|
||||
assert "message" in variables
|
||||
assert "role" in variables["message"]
|
||||
assert "content" in variables["message"]
|
||||
|
||||
def test_detailed_analysis(self, basic_jinja_template_analyzer):
|
||||
"""Test the detailed analysis of variable usage."""
|
||||
LOG.info("Testing detailed analysis")
|
||||
|
||||
analysis = basic_jinja_template_analyzer.analyze_template()
|
||||
|
||||
assert analysis["messages"]["is_iterated"] is True
|
||||
assert "role" in analysis["message"]["accessed_properties"]
|
||||
assert "content" in analysis["message"]["accessed_properties"]
|
||||
|
||||
assert analysis["add_generation_prompt"]["is_conditional"] is True
|
||||
assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0
|
||||
|
||||
assert not analysis["eos_token"]["is_iterated"]
|
||||
assert len(analysis["eos_token"]["accessed_properties"]) == 0
|
||||
|
||||
def test_nested_property_access(self):
|
||||
"""Test handling of nested property access."""
|
||||
LOG.info("Testing nested property access")
|
||||
|
||||
template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
variables = analyzer.get_template_variables()
|
||||
|
||||
assert "user" in variables
|
||||
assert "profile" in variables["user"]
|
||||
assert "settings" in variables["user"]
|
||||
|
||||
def test_loop_variable_handling(self):
|
||||
"""Test handling of loop variables and their properties."""
|
||||
LOG.info("Testing loop variable handling")
|
||||
|
||||
template = """
|
||||
{% for item in items %}
|
||||
{{ item.name }}
|
||||
{% for subitem in item.subitems %}
|
||||
{{ subitem.value }}
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
analysis = analyzer.analyze_template()
|
||||
|
||||
assert analysis["items"]["is_iterated"]
|
||||
assert "name" in analysis["item"]["accessed_properties"]
|
||||
assert "subitems" in analysis["item"]["accessed_properties"]
|
||||
|
||||
def test_conditional_variable_usage(self):
|
||||
"""Test detection of variables used in conditional statements."""
|
||||
LOG.info("Testing conditional variable usage")
|
||||
|
||||
template = """
|
||||
{% if user.is_admin and config.debug_mode %}
|
||||
{{ debug_info }}
|
||||
{% endif %}
|
||||
"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
analysis = analyzer.analyze_template()
|
||||
|
||||
assert analysis["user"]["is_conditional"]
|
||||
assert analysis["config"]["is_conditional"]
|
||||
assert "is_admin" in analysis["user"]["accessed_properties"]
|
||||
assert "debug_mode" in analysis["config"]["accessed_properties"]
|
||||
|
||||
def test_complex_expressions(self):
|
||||
"""Test handling of complex expressions and filters."""
|
||||
LOG.info("Testing complex expressions and filters")
|
||||
|
||||
template = """
|
||||
{{ user.name | upper }}
|
||||
{{ messages | length > 0 and messages[0].content }}
|
||||
{{ data['key'].nested['value'] }}
|
||||
"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
variables = analyzer.get_template_variables()
|
||||
|
||||
assert "user" in variables
|
||||
assert "name" in variables["user"]
|
||||
assert "messages" in variables
|
||||
assert "content" in variables["messages"]
|
||||
assert "data" in variables
|
||||
|
||||
def test_basic_msg_vars(self, basic_jinja_template_analyzer):
|
||||
"""Test that the basic message variables are correctly identified."""
|
||||
LOG.info("Testing basic message variables")
|
||||
|
||||
variables = basic_jinja_template_analyzer.get_message_vars()
|
||||
assert variables == {"role", "content"}
|
||||
|
||||
def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
|
||||
"""Test that the mixtral message variables are correctly identified."""
|
||||
LOG.info("Testing mixtral message variables")
|
||||
|
||||
variables = mistral_jinja_template_analyzer.get_message_vars()
|
||||
assert variables == {"role", "content", "tool_calls", "tool_call_id"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -302,3 +302,22 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_message_property_mappings(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestModelsUtils:
|
||||
mocked_load_model_config.return_value = {}
|
||||
with pytest.raises(ValueError) as exc:
|
||||
# Should error before hitting tokenizer, so we pass in an empty str
|
||||
load_model(cfg, tokenizer="")
|
||||
load_model(cfg, tokenizer="") # type: ignore
|
||||
assert (
|
||||
"shifted-sparse attention does not currently support sample packing"
|
||||
in str(exc.value)
|
||||
@@ -116,3 +116,79 @@ class TestModelsUtils:
|
||||
assert self.model_loader.model_kwargs.get(
|
||||
"quantization_config", BitsAndBytesConfig
|
||||
)
|
||||
|
||||
def test_message_property_mapping(self):
|
||||
"""Test message property mapping configuration validation"""
|
||||
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
||||
|
||||
# Test legacy fields are mapped orrectly
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="role_field",
|
||||
message_field_content="content_field",
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "role_field",
|
||||
"content": "content_field",
|
||||
}
|
||||
|
||||
# Test direct message_property_mapping works
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_property_mappings={
|
||||
"role": "custom_role",
|
||||
"content": "custom_content",
|
||||
},
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "custom_role",
|
||||
"content": "custom_content",
|
||||
}
|
||||
|
||||
# Test both legacy and new fields work when they match
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="same_role",
|
||||
message_property_mappings={"role": "same_role"},
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "same_role",
|
||||
"content": "content",
|
||||
}
|
||||
|
||||
# Test both legacy and new fields work when they don't overlap
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="role_field",
|
||||
message_property_mappings={"content": "content_field"},
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "role_field",
|
||||
"content": "content_field",
|
||||
}
|
||||
|
||||
# Test no role or content provided
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
}
|
||||
|
||||
# Test error when legacy and new fields conflict
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="legacy_role",
|
||||
message_property_mappings={"role": "different_role"},
|
||||
)
|
||||
assert "Conflicting message role fields" in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SFTDataset(
|
||||
path="test_path",
|
||||
message_field_content="legacy_content",
|
||||
message_property_mappings={"content": "different_content"},
|
||||
)
|
||||
assert "Conflicting message content fields" in str(exc_info.value)
|
||||
|
||||
Reference in New Issue
Block a user