Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
@@ -46,3 +46,9 @@ gcsfs>=2024.5.0
|
|||||||
trl==0.9.6
|
trl==0.9.6
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
# lm eval harness
|
||||||
|
lm_eval==0.4.4
|
||||||
|
langdetect==1.0.9
|
||||||
|
immutabledict==4.2.0
|
||||||
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|||||||
@@ -55,8 +55,22 @@ LOG = logging.getLogger("axolotl.scripts")
|
|||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
AXOLOTL_LOGO = """
|
||||||
|
#@@ #@@ @@# @@#
|
||||||
|
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||||
|
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||||
|
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||||
|
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||||
|
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||||
|
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||||
|
@@@@ @@@@@@@@@@@@@@@@
|
||||||
|
"""
|
||||||
|
|
||||||
def print_axolotl_text_art(suffix=None):
|
|
||||||
|
def print_legacy_axolotl_text_art(suffix=None):
|
||||||
font = "nancyj"
|
font = "nancyj"
|
||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
@@ -69,6 +83,13 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
print_dep_versions()
|
print_dep_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def print_axolotl_text_art(
    **kwargs,  # pylint: disable=unused-argument
):
    """
    Print the Axolotl ASCII logo on the main process only.

    Any keyword arguments are accepted and ignored (presumably so callers
    written against the legacy ``suffix`` signature keep working -- confirm).
    """
    # Every other process stays silent so the banner is printed once.
    if not is_main_process():
        return
    print(AXOLOTL_LOGO)
|
||||||
|
|
||||||
|
|
||||||
def print_dep_versions():
|
def print_dep_versions():
|
||||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||||
max_len = max(len(pkg) for pkg in packages)
|
max_len = max(len(pkg) for pkg in packages)
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -20,6 +18,7 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.prompt_strategies.sharegpt import (
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
register_llama3_template,
|
register_llama3_template,
|
||||||
@@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return do_train(parsed_cfg, parsed_cli_args)
|
return do_train(parsed_cfg, parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
def do_train(cfg, cli_args) -> None:
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
@@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
|
del model
|
||||||
|
del tokenizer
|
||||||
|
|
||||||
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -159,6 +159,29 @@ class BasePlugin:
|
|||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def post_train(self, cfg, model):
    """
    Hook for work a plugin wants to do once training is complete.

    The base implementation is a no-op; plugins override it as needed.

    Parameters:
        cfg (dict): The axolotl configuration
        model (object): The loaded model.

    Returns:
        None
    """
|
||||||
|
|
||||||
|
def post_train_unload(self, cfg):
    """
    Hook for work a plugin wants to do after training is complete and the
    model has been unloaded.

    The base implementation is a no-op; plugins override it as needed.

    Parameters:
        cfg (dict): The configuration for the plugin.

    Returns:
        None
    """
|
||||||
|
|
||||||
|
|
||||||
def load_plugin(plugin_name: str) -> BasePlugin:
|
def load_plugin(plugin_name: str) -> BasePlugin:
|
||||||
"""
|
"""
|
||||||
@@ -381,3 +404,17 @@ class PluginManager:
|
|||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
def post_train_unload(self, cfg):
    """
    Calls the post_train_unload method of all registered plugins.

    Plugins are invoked in the order they appear in ``self.plugins``; an
    exception raised by one plugin propagates and stops the remaining ones.

    Parameters:
        cfg (dict): The configuration for the plugins.

    Returns:
        None
    """
    for plugin in self.plugins:
        plugin.post_train_unload(cfg)
|
||||||
|
|||||||
13
src/axolotl/integrations/lm_eval/README.md
Normal file
13
src/axolotl/integrations/lm_eval/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# LM Eval Harness
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.lm_eval.LMEvalPlugin
|
||||||
|
|
||||||
|
lm_eval_tasks:
|
||||||
|
- gsm8k
|
||||||
|
- hellaswag
|
||||||
|
- arc_easy
|
||||||
|
```
|
||||||
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
Module for the Plugin for LM Eval Harness
|
||||||
|
"""
|
||||||
|
import subprocess # nosec
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
|
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
class LMEvalPlugin(BasePlugin):
    """
    Plugin for LM Evaluation Harness integration with Axolotl.
    """

    def get_input_args(self):
        """
        Return the dotted import path of the args model for this plugin.

        Returns:
            str: Path to ``LMEvalArgs``, which declares ``lm_eval_tasks`` and
            ``lm_eval_batch_size``.
        """
        return "axolotl.integrations.lm_eval.LMEvalArgs"

    def post_train_unload(self, cfg):
        """
        Run the ``lm_eval`` CLI against the model saved in ``cfg.output_dir``.

        Results are written to
        ``<output_dir>/lm_eval_results/<YYYYmmdd_HHMMSS>``. Does nothing when
        no tasks are configured.

        NOTE(review): the caller does not appear to restrict this hook to the
        main process, so a multi-process run may launch one evaluation per
        process -- confirm whether a main-process guard is wanted.

        Parameters:
            cfg (dict): The axolotl configuration. Reads ``lm_eval_tasks``,
                ``lm_eval_batch_size``, ``output_dir``, ``flash_attention``
                and ``bf16``.

        Returns:
            None

        Raises:
            subprocess.CalledProcessError: If ``lm_eval`` exits non-zero.
            FileNotFoundError: If the ``lm_eval`` executable is not found.
        """
        # The declared default for lm_eval_tasks is an empty list; without any
        # task there is nothing to evaluate, so skip launching the CLI.
        if not cfg.lm_eval_tasks:
            return

        tasks = ",".join(cfg.lm_eval_tasks)
        fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
        # float16 is used whenever bf16 is not enabled.
        dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"

        # Normalise to str so a path-like output_dir works as well as a str.
        output_dir = str(cfg.output_dir)
        separator = "" if output_dir.endswith("/") else "/"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = f"{output_dir}{separator}lm_eval_results/{timestamp}"

        # Argument list with the default shell=False: no shell interpolation.
        subprocess.run(  # nosec
            [
                "lm_eval",
                "--model",
                "hf",
                "--model_args",
                f"pretrained={output_dir}{fa2}{dtype}",
                "--tasks",
                tasks,
                "--batch_size",
                str(cfg.lm_eval_batch_size),
                "--output_path",
                output_path,
            ],
            check=True,
        )
|
||||||
15
src/axolotl/integrations/lm_eval/args.py
Normal file
15
src/axolotl/integrations/lm_eval/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Module for handling lm eval harness input arguments.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LMEvalArgs(BaseModel):
    """
    Input args for lm eval harness

    Declares the config keys read by the lm_eval plugin's
    ``post_train_unload`` hook.
    """

    # Task names to evaluate (e.g. gsm8k, hellaswag); the plugin joins them
    # with commas for the CLI's --tasks argument. Empty by default.
    lm_eval_tasks: List[str] = []
    # Value passed to the CLI's --batch_size argument (stringified by the plugin).
    lm_eval_batch_size: Optional[int] = 8
|
||||||
@@ -30,7 +30,9 @@ _CHAT_TEMPLATES = {
|
|||||||
|
|
||||||
CHAT_TEMPLATES = {
|
CHAT_TEMPLATES = {
|
||||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
|
||||||
|
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
|
||||||
|
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
|||||||
@@ -1051,6 +1051,26 @@ class AxolotlInputConfig(
|
|||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data.get("do_bench_eval") and not (
|
||||||
|
data.get("evals_per_epoch") or data.get("eval_steps")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_test_datasets_bench(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("do_bench_eval")
|
||||||
|
and not data.get("test_datasets")
|
||||||
|
and not data.get("val_set_size")
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
||||||
|
)
|
||||||
|
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
Reference in New Issue
Block a user