Compare commits
13 Commits
nca-pair
...
custom-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9a1f288cf | ||
|
|
1e1921b794 | ||
|
|
1634ac82e0 | ||
|
|
02982733ec | ||
|
|
5d97e65f95 | ||
|
|
2147cf6837 | ||
|
|
50421c8b1d | ||
|
|
b32c08f8cc | ||
|
|
fff06af8d0 | ||
|
|
796a085b2f | ||
|
|
cb78a36374 | ||
|
|
8b9c15b17f | ||
|
|
9e1480e9ca |
37
README.md
37
README.md
@@ -34,6 +34,7 @@ Features:
|
|||||||
- [Mac](#mac)
|
- [Mac](#mac)
|
||||||
- [Google Colab](#google-colab)
|
- [Google Colab](#google-colab)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
|
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
@@ -292,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
|||||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Launching on public clouds via dstack
|
||||||
|
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
|
||||||
|
|
||||||
|
Write a job description in YAML as below:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# dstack.yaml
|
||||||
|
type: task
|
||||||
|
|
||||||
|
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1
|
||||||
|
|
||||||
|
env:
|
||||||
|
- HUGGING_FACE_HUB_TOKEN
|
||||||
|
- WANDB_API_KEY
|
||||||
|
|
||||||
|
commands:
|
||||||
|
- accelerate launch -m axolotl.cli.train config.yaml
|
||||||
|
|
||||||
|
ports:
|
||||||
|
- 6006
|
||||||
|
|
||||||
|
resources:
|
||||||
|
gpu:
|
||||||
|
memory: 24GB..
|
||||||
|
count: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install dstack
|
||||||
|
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
|
||||||
|
```
|
||||||
|
|
||||||
|
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
|
deepspeed=None,
|
||||||
|
fsdp=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg.flash_attention = False
|
parsed_cfg.flash_attention = False
|
||||||
parsed_cfg.deepspeed = None
|
parsed_cfg.deepspeed = None
|
||||||
parsed_cfg.fsdp = None
|
parsed_cfg.fsdp = None
|
||||||
|
parsed_cfg.fsdp_config = None
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
if parsed_cfg.chat_template == "chatml":
|
||||||
LOG.info(
|
if parsed_cfg.default_system_message:
|
||||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
LOG.info(
|
||||||
)
|
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
register_chatml_template(parsed_cfg.default_system_message)
|
)
|
||||||
else:
|
register_chatml_template(parsed_cfg.default_system_message)
|
||||||
register_chatml_template()
|
else:
|
||||||
|
register_chatml_template()
|
||||||
|
elif parsed_cfg.chat_template == "llama3":
|
||||||
|
if parsed_cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(parsed_cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
|
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if cfg.rl: # and cfg.rl != "orpo":
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return AxolotlMambaTrainer
|
return AxolotlMambaTrainer
|
||||||
|
if self.cfg.custom_trainer_cls:
|
||||||
|
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
|
||||||
|
return importlib.import_module(_module, _cls)
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -1526,6 +1529,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = ORPOConfig
|
training_args_cls = ORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
if self.cfg.max_prompt_len:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
training_args = training_args_cls(
|
training_args = training_args_cls(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
|
|||||||
@@ -123,6 +123,17 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||||
|
if self.system_message:
|
||||||
|
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||||
|
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||||
|
else:
|
||||||
|
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||||
|
return
|
||||||
if self.sep_style == SeparatorStyle.GEMMA:
|
if self.sep_style == SeparatorStyle.GEMMA:
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
raise ValueError("Gemma chat template does not support system messages")
|
raise ValueError("Gemma chat template does not support system messages")
|
||||||
|
|||||||
133
src/axolotl/prompt_strategies/dpo/llama3.py
Normal file
133
src/axolotl/prompt_strategies/dpo/llama3.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
DPO strategies for llama-3 chat template
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def argilla(
|
||||||
|
cfg,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def argilla_chat(
|
||||||
|
cfg,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
for argilla/dpo-mix-7k conversations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def icr(
|
||||||
|
cfg,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
chatml transforms for datasets with system, input, chosen, rejected
|
||||||
|
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
For Intel Orca DPO Pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_pairs(
|
||||||
|
cfg, **kwargs
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
for ultrafeedback binarized conversations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
|
|||||||
name="chatml",
|
name="chatml",
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
@@ -32,83 +32,63 @@ def register_chatml_template(system_message=None):
|
|||||||
name="chatml_glaive",
|
name="chatml_glaive",
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def register_llama3_template(system_message=None):
|
||||||
conversation = (
|
system_message = system_message or "You are a helpful assistant."
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
register_conv_template(
|
||||||
)
|
Conversation(
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
name="llama3",
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
system_message=system_message,
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
roles=("user", "assistant"),
|
||||||
ShareGPTPrompterV2(
|
sep_style=SeparatorStyle.LLAMA3,
|
||||||
conversation=conversation,
|
sep="",
|
||||||
role_key_model=field_model,
|
stop_str="<|eot_id|>",
|
||||||
role_key_human=field_human,
|
stop_token_ids=[128001, 128009],
|
||||||
roles=roles,
|
)
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
|
||||||
)
|
|
||||||
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(
|
|
||||||
conversation=conversation,
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_guanaco(tokenizer, cfg):
|
def build_loader(
|
||||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||||
ShareGPTPrompterV2(),
|
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||||
tokenizer,
|
default_conversation: Optional[str] = None,
|
||||||
cfg.train_on_inputs,
|
):
|
||||||
cfg.sequence_len,
|
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
)
|
conversation = (
|
||||||
|
ds_cfg["conversation"]
|
||||||
|
if ds_cfg and "conversation" in ds_cfg
|
||||||
|
else default_conversation
|
||||||
|
)
|
||||||
|
field_human = (
|
||||||
|
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
|
)
|
||||||
|
field_model = (
|
||||||
|
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
|
)
|
||||||
|
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||||
|
strategy = tokenization_strategy_cls(
|
||||||
|
prompter_cls(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
roles=roles,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||||
|
strategy.strict = ds_cfg["strict"]
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
return _load
|
||||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"]
|
|
||||||
if ds_cfg and "conversation" in ds_cfg
|
|
||||||
else "chatml_glaive"
|
|
||||||
)
|
|
||||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(conversation=conversation),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
@@ -158,7 +138,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
|
SimpleShareGPTPromptTokenizingStrategy
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||||
"""
|
"""
|
||||||
@@ -209,3 +191,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
|
|||||||
conversation = merge_consecutive_messages(conversation)
|
conversation = merge_consecutive_messages(conversation)
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
|
||||||
|
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_ultrachat = build_loader(
|
||||||
|
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||||
|
)
|
||||||
|
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_glaive = build_loader(
|
||||||
|
GlaiveShareGPTPromptTokenizingStrategy,
|
||||||
|
ShareGPTPrompterV2,
|
||||||
|
default_conversation="chatml_glaive",
|
||||||
|
)
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
|
|||||||
"chatml": "<|im_start|>{ROLE}",
|
"chatml": "<|im_start|>{ROLE}",
|
||||||
"zephyr": "<|{ROLE}|>",
|
"zephyr": "<|{ROLE}|>",
|
||||||
"vicuna_v1.1": "{ROLE}",
|
"vicuna_v1.1": "{ROLE}",
|
||||||
|
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -348,7 +349,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
if (
|
||||||
|
role != "assistant"
|
||||||
|
): # back to back assistant calls may be okay for tool calls
|
||||||
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
|
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
|
|||||||
@@ -212,6 +212,10 @@ def train(
|
|||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
|
trainer.model.save_pretrained(
|
||||||
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
|||||||
@@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||||
"""Callback to save model on train end"""
|
"""Callback to save model on train end"""
|
||||||
|
|
||||||
|
def on_step_end( # pylint: disable=unused-argument
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# Save
|
||||||
|
if state.global_step >= state.max_steps:
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
def on_train_end( # pylint: disable=unused-argument
|
def on_train_end( # pylint: disable=unused-argument
|
||||||
self, args, state, control, **kwargs
|
self, args, state, control, **kwargs
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
|
|||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
|
||||||
}
|
}
|
||||||
|
|
||||||
if user_choice in templates:
|
if user_choice in templates:
|
||||||
|
|||||||
@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
if self.multipack_attn:
|
if self.multipack_attn:
|
||||||
arrays = [
|
arrays = [
|
||||||
(i + 1) * np.array(item[feature])
|
(i + 1) * np.array(item)
|
||||||
for i, item in enumerate(features[feature])
|
for i, item in enumerate(features[feature])
|
||||||
if feature in item
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
arrays = [(1) * np.array(item) for item in features[feature]]
|
arrays = [(1) * np.array(item) for item in features[feature]]
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
|
|||||||
inst = "inst" # pylint: disable=invalid-name
|
inst = "inst" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
@@ -516,6 +517,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
min_sample_len: Optional[int] = None
|
min_sample_len: Optional[int] = None
|
||||||
|
max_prompt_len: int = Field(
|
||||||
|
default=512, metadata={"help": "maximum prompt length for RL training"}
|
||||||
|
)
|
||||||
sample_packing: Optional[bool] = None
|
sample_packing: Optional[bool] = None
|
||||||
eval_sample_packing: Optional[bool] = None
|
eval_sample_packing: Optional[bool] = None
|
||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
@@ -557,6 +561,8 @@ class AxolotlInputConfig(
|
|||||||
torch_compile: Optional[bool] = None
|
torch_compile: Optional[bool] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
|
|
||||||
|
custom_trainer_cls: Optional[str] = None
|
||||||
|
|
||||||
max_steps: Optional[int] = None
|
max_steps: Optional[int] = None
|
||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
warmup_ratio: Optional[float] = None
|
warmup_ratio: Optional[float] = None
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -504,6 +505,9 @@ def load_model(
|
|||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
}
|
}
|
||||||
|
# Exclude mamba blocks from int8 quantization for jamba
|
||||||
|
if cfg.model_config_type == "jamba":
|
||||||
|
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
|
|||||||
GlaiveShareGPTPromptTokenizingStrategy,
|
GlaiveShareGPTPromptTokenizingStrategy,
|
||||||
SimpleShareGPTPromptTokenizingStrategy,
|
SimpleShareGPTPromptTokenizingStrategy,
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import ShareGPTPrompterV2
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sharegpt_dataset")
|
@pytest.fixture(name="sharegpt_dataset")
|
||||||
@@ -115,7 +117,53 @@ def fixture_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestSharegpt:
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
def fixture_llama3_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
tokenizer.eos_token = "<|eot_id|>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharegptLlama3:
|
||||||
|
"""Test class for ShareGPT style datasets with llama-3 prompts"""
|
||||||
|
|
||||||
|
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
|
||||||
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation="llama3",
|
||||||
|
role_key_model=None,
|
||||||
|
role_key_human=None,
|
||||||
|
),
|
||||||
|
llama3_tokenizer,
|
||||||
|
False, # train_on_inputs
|
||||||
|
2048, # sequence_len
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
strategy, sharegpt_dataset, process_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
128000, # bos
|
||||||
|
128006, 9125, 128007, # system header
|
||||||
|
271, 31724, 128009, # sys prompt, eot
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 15339, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 15339, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharegptChatML:
|
||||||
"""
|
"""
|
||||||
Test class for sharegpt prompter
|
Test class for sharegpt prompter
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user