Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
@@ -46,3 +46,9 @@ gcsfs>=2024.5.0
|
||||
trl==0.9.6
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
# lm eval harness
|
||||
lm_eval==0.4.4
|
||||
langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
@@ -55,8 +55,22 @@ LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
AXOLOTL_LOGO = """
|
||||
#@@ #@@ @@# @@#
|
||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
"""
|
||||
|
||||
def print_axolotl_text_art(suffix=None):
|
||||
|
||||
def print_legacy_axolotl_text_art(suffix=None):
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
@@ -69,6 +83,13 @@ def print_axolotl_text_art(suffix=None):
|
||||
print_dep_versions()
|
||||
|
||||
|
||||
def print_axolotl_text_art(
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
print(AXOLOTL_LOGO)
|
||||
|
||||
|
||||
def print_dep_versions():
|
||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||
max_len = max(len(pkg) for pkg in packages)
|
||||
|
||||
@@ -3,13 +3,11 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -20,6 +18,7 @@ from axolotl.cli import (
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
@@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return do_train(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
def do_train(cfg, cli_args) -> None:
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
@@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -159,6 +159,29 @@ class BasePlugin:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
|
||||
def post_train(self, cfg, model):
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> BasePlugin:
|
||||
"""
|
||||
@@ -381,3 +404,17 @@ class PluginManager:
|
||||
for plugin in self.plugins:
|
||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||
return callbacks
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
"""
|
||||
Calls the post_train_unload method of all registered plugins.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
13
src/axolotl/integrations/lm_eval/README.md
Normal file
13
src/axolotl/integrations/lm_eval/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# LM Eval Harness
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.lm_eval.LMEvalPlugin
|
||||
|
||||
lm_eval_tasks:
|
||||
- gsm8k
|
||||
- hellaswag
|
||||
- arc_easy
|
||||
```
|
||||
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Module for the Plugin for LM Eval Harness
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from datetime import datetime
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
|
||||
class LMEvalPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for LM Evaluation Harness integraton with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
tasks = ",".join(cfg.lm_eval_tasks)
|
||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||
output_path = cfg.output_dir
|
||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
subprocess.run( # nosec
|
||||
[
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||
"--tasks",
|
||||
tasks,
|
||||
"--batch_size",
|
||||
str(cfg.lm_eval_batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
15
src/axolotl/integrations/lm_eval/args.py
Normal file
15
src/axolotl/integrations/lm_eval/args.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Module for handling lm eval harness input arguments.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LMEvalArgs(BaseModel):
|
||||
"""
|
||||
Input args for lm eval harness
|
||||
"""
|
||||
|
||||
lm_eval_tasks: List[str] = []
|
||||
lm_eval_batch_size: Optional[int] = 8
|
||||
@@ -30,7 +30,9 @@ _CHAT_TEMPLATES = {
|
||||
|
||||
CHAT_TEMPLATES = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
|
||||
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
|
||||
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
|
||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||
|
||||
@@ -1051,6 +1051,26 @@ class AxolotlInputConfig(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
data.get("evals_per_epoch") or data.get("eval_steps")
|
||||
):
|
||||
raise ValueError(
|
||||
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_test_datasets_bench(cls, data):
|
||||
if (
|
||||
data.get("do_bench_eval")
|
||||
and not data.get("test_datasets")
|
||||
and not data.get("val_set_size")
|
||||
):
|
||||
LOG.warning(
|
||||
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
||||
)
|
||||
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
Reference in New Issue
Block a user