Compare commits


13 Commits

Author SHA1 Message Date
Wing Lian
e9a1f288cf support for custom trainer_cls from config 2024-05-14 18:57:53 -04:00
Ali Mosavian
1e1921b794 FIX: max_length and max_prompt_length was not being sent to ORPOTrainer (#1584)
* FIX: TRL trainer preprocessing step was running in one process

* FIX: max_length and max_prompt_length was not being sent to ORPOTrainer

* FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour

* FIX: Removed change from a different PR

* FIX: Black fix

* explicitly set max prompt len for orpo config

---------

Co-authored-by: Ali Mosavian <ali.mosavian@kry.se>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-14 08:51:17 -04:00
Wing Lian
1634ac82e0 make sure to save on the last step (#1615) 2024-05-14 08:48:39 -04:00
Wing Lian
02982733ec fix attention mask collation (#1603) 2024-05-14 08:17:30 -04:00
Chansung Park
5d97e65f95 add dstack section (#1612) [skip ci]
* add dstack section

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-14 08:13:45 -04:00
Wing Lian
2147cf6837 Llama3 dpo (#1610)
* add dpo llama3

* fix dpo bos and eos

* bos token gets added automatically by the tokenizer

* explicit <|end_of_text|> not needed, as eot_id is sufficient

---------

Co-authored-by: Nero10578 <owenarliawan@gmail.com>
2024-05-11 18:29:03 -04:00
Ram
50421c8b1d feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553)
* Add prompt strategies

* Update modified URL

* Update modified URL

* Update fastchat_conversation_turns.py

* Update register function

* Remove extra /n for system prompt

* Fix return

* Fix BOS

* Update requirements, pylint

* Linting

* Linting

* fix tuples, make sure to set system message in template

* tests for llama3 tokenization

* fix conditionals for loading chat template

---------

Co-authored-by: Ram <ram@Rams-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-11 00:08:04 -04:00
Antoni-Joan Solergibert
b32c08f8cc adding llama3 fastchat conversation monkeypatch (#1539)
* adding llama3 fastchat conversation monkeypatch

* Updated conversation turns to work with PR3259 of FastChat

* fixed bos token

* bump fastchat version

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-10 10:40:05 -04:00
Wing Lian
fff06af8d0 ignore the fsdp_config section too (#1606) [skip ci] 2024-05-09 13:30:39 -04:00
Wing Lian
796a085b2f make sure to save the lora adapter at the end of RL/dpo training (#1573) 2024-05-08 10:39:33 -04:00
Wing Lian
cb78a36374 improve tool handling roles (#1587) 2024-05-07 11:30:40 -04:00
NanoCode012
8b9c15b17f feat: exclude mamba blocks for jamba (#1578) 2024-05-07 22:52:57 +09:00
Chirag Jain
9e1480e9ca Pass deepspeed and fsdp as None explicitly when merging adapters to allow custom device_map (#1575) 2024-05-07 22:47:55 +09:00
17 changed files with 369 additions and 84 deletions

README.md

@@ -34,6 +34,7 @@ Features:
 - [Mac](#mac)
 - [Google Colab](#google-colab)
 - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
+- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
 - [Dataset](#dataset)
 - [Config](#config)
 - [Train](#train)
@@ -292,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
 HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
 ```
+
+#### Launching on public clouds via dstack
+
+To launch on a GPU instance (on-demand or spot) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
+
+Write a job description in YAML as below:
+
+```yaml
+# dstack.yaml
+type: task
+
+image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1
+
+env:
+  - HUGGING_FACE_HUB_TOKEN
+  - WANDB_API_KEY
+
+commands:
+  - accelerate launch -m axolotl.cli.train config.yaml
+
+ports:
+  - 6006
+
+resources:
+  gpu:
+    memory: 24GB..
+    count: 2
+```
+
+Then run the job with the `dstack run` command. Append the `--spot` option if you want a spot instance. `dstack run` will show you the cheapest instance across the supported clouds:
+
+```bash
+pip install dstack
+HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml  # --spot
+```
+
+For finer-grained use cases, please refer to the official [dstack documentation](https://dstack.ai/docs/) and the detailed description of the [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) in the official repository.
+
 ### Dataset

 Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

requirements.txt

@@ -28,7 +28,7 @@ scipy
 scikit-learn==1.2.2
 pynvml
 art
-fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
+fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
 gradio==3.50.2
 tensorboard

src/axolotl/cli/merge_lora.py

@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         load_in_8bit=False,
         load_in_4bit=False,
         flash_attention=False,
+        deepspeed=None,
+        fsdp=None,
         **kwargs,
     )
@@ -40,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cfg.flash_attention = False
     parsed_cfg.deepspeed = None
     parsed_cfg.fsdp = None
+    parsed_cfg.fsdp_config = None

     do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
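With `deepspeed` and `fsdp` forced to `None`, adapter merging no longer inherits distributed settings and can use a custom `device_map`. A typical invocation (paths are placeholders) would be:

```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
```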

src/axolotl/cli/preprocess.py

@@ -19,7 +19,10 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)

 LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )

-    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
-        LOG.info(
-            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
-        )
-        register_chatml_template(parsed_cfg.default_system_message)
-    else:
-        register_chatml_template()
+    if parsed_cfg.chat_template == "chatml":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_chatml_template(parsed_cfg.default_system_message)
+        else:
+            register_chatml_template()
+    elif parsed_cfg.chat_template == "llama3":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_llama3_template(parsed_cfg.default_system_message)
+        else:
+            register_llama3_template()

     if not parsed_cfg.dataset_prepared_path:
         msg = (

src/axolotl/cli/train.py

@@ -19,7 +19,10 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)
 from axolotl.train import train

 LOG = logging.getLogger("axolotl.cli.train")
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()

+    if cfg.chat_template == "llama3" and cfg.default_system_message:
+        LOG.info(
+            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_llama3_template(cfg.default_system_message)
+    else:
+        register_llama3_template()
+
     if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
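A minimal config sketch (values illustrative) that exercises the new registration path during training:

```yaml
chat_template: llama3
default_system_message: You are a helpful assistant.
```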

src/axolotl/core/trainer_builder.py

@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
             return AxolotlMambaTrainer
+        if self.cfg.custom_trainer_cls:
+            _module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
+            return importlib.import_module(_module, _cls)
         return AxolotlTrainer

     def build(self, total_num_steps):
@@ -1526,6 +1529,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            if self.cfg.max_prompt_len:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
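Together with the config fields added further down in this compare, these hooks would be driven from config entries along these lines (the module path is hypothetical):

```yaml
# dotted path, resolved via rsplit(".", 1) in the trainer builder
custom_trainer_cls: my_package.trainers.MyCustomTrainer

# ORPO: max_length comes from sequence_len, max_prompt_length from max_prompt_len
rl: orpo
sequence_len: 2048
max_prompt_len: 512
```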

src/axolotl/monkeypatch/fastchat_conversation_turns.py

@@ -123,6 +123,17 @@ def get_turns(  # pylint: disable=too-many-return-statements
             else:
                 yield role, ""
             return
+        if self.sep_style == SeparatorStyle.LLAMA3:
+            if self.system_message:
+                # For llama3, the system message is NOT incorporated into the first human instruction
+                # All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
+                yield "", system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
+                else:
+                    yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
+            return
         if self.sep_style == SeparatorStyle.GEMMA:
             if self.system_message:
                 raise ValueError("Gemma chat template does not support system messages")
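As a standalone illustration (not part of the diff), the turns yielded above assemble a prompt like this for a single exchange:

```python
# Reconstructing the llama3 turn layout; header/eot strings taken from the diff.
system_prompt = (
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful assistant.<|eot_id|>"
)
turns = [
    ("", system_prompt),
    ("<|start_header_id|>user<|end_header_id|>\n\n", "Hello!<|eot_id|>"),
    ("<|start_header_id|>assistant<|end_header_id|>\n\n", "Hi there!<|eot_id|>"),
]
prompt = "".join(role + message for role, message in turns)
print(prompt)
```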

src/axolotl/prompt_strategies/dpo/llama3.py

@@ -0,0 +1,133 @@
+"""
+DPO strategies for llama-3 chat template
+"""
+
+
+def argilla(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
+        sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/dpo-mix-7k conversations
+    """
+
+    def transform_fn(sample):
+        sample[
+            "prompt"
+        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def icr(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    chatml transforms for datasets with system, input, chosen, rejected
+    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
+        sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    For Intel Orca DPO Pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
+        sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def prompt_pairs(
+    cfg, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
+        sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for ultrafeedback binarized conversations
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
+        return sample
+
+    return transform_fn
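Assuming these strategies load under a `llama3` namespace, mirroring the existing `chatml` DPO strategies, a dataset stanza would look like:

```yaml
# hypothetical config excerpt; the llama3.intel type name is inferred
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    type: llama3.intel
```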

src/axolotl/prompt_strategies/sharegpt.py

@@ -1,7 +1,7 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Type

 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
             name="chatml",
             system_template="<|im_start|>system\n{system_message}",
             system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            roles=("<|im_start|>user", "<|im_start|>assistant"),
             sep_style=SeparatorStyle.CHATML,
             sep="<|im_end|>",
         )
@@ -32,83 +32,63 @@
             name="chatml_glaive",
             system_template="<|im_start|>system\n{system_message}",
             system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
+            roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
             sep_style=SeparatorStyle.CHATML,
             sep="<|im_end|>",
         )
     )


-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
-    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
-    strategy = SimpleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-            role_key_model=field_model,
-            role_key_human=field_human,
-            roles=roles,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
-
-
-def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    strategy = UltrachatShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
-
-
-def load_role(tokenizer, cfg):
-    return SimpleRoleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-
-
-def load_guanaco(tokenizer, cfg):
-    return GuanacoShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-
-
-def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"]
-        if ds_cfg and "conversation" in ds_cfg
-        else "chatml_glaive"
-    )
-    return GlaiveShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(conversation=conversation),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
+def register_llama3_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="llama3",
+            system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+            system_message=system_message,
+            roles=("user", "assistant"),
+            sep_style=SeparatorStyle.LLAMA3,
+            sep="",
+            stop_str="<|eot_id|>",
+            stop_token_ids=[128001, 128009],
+        )
+    )
+
+
+def build_loader(
+    tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
+    prompter_cls: Type["ShareGPTPrompterV2"],
+    default_conversation: Optional[str] = None,
+):
+    def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+        conversation = (
+            ds_cfg["conversation"]
+            if ds_cfg and "conversation" in ds_cfg
+            else default_conversation
+        )
+        field_human = (
+            ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+        )
+        field_model = (
+            ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+        )
+        roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
+        strategy = tokenization_strategy_cls(
+            prompter_cls(
+                conversation=conversation,
+                role_key_model=field_model,
+                role_key_human=field_human,
+                roles=roles,
+            ),
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
+            strategy.strict = ds_cfg["strict"]
+        return strategy
+
+    return _load


 class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -158,7 +138,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
         return turns


-class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+class SimpleRoleShareGPTPromptTokenizingStrategy(
+    SimpleShareGPTPromptTokenizingStrategy
+):
     """
     basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
     """
@@ -209,3 +191,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
         conversation = merge_consecutive_messages(conversation)

         return conversation
+
+
+load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_ultrachat = build_loader(
+    UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
+)
+load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_glaive = build_loader(
+    GlaiveShareGPTPromptTokenizingStrategy,
+    ShareGPTPrompterV2,
+    default_conversation="chatml_glaive",
+)

src/axolotl/prompters.py

@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
     "chatml": "<|im_start|>{ROLE}",
     "zephyr": "<|{ROLE}|>",
     "vicuna_v1.1": "{ROLE}",
+    "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
 }
@@ -348,7 +349,10 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
             )

             if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
-                LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
+                if (
+                    role != "assistant"
+                ):  # back to back assistant calls may be okay for tool calls
+                    LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")

             conv.append_message(role, sentence["value"])

src/axolotl/train.py

@@ -212,6 +212,10 @@ def train(
     if cfg.flash_optimum and BetterTransformer:
         model = BetterTransformer.reverse(model)

+    if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+        trainer.model.save_pretrained(
+            cfg.output_dir, safe_serialization=safe_serialization
+        )
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

     if not cfg.hub_model_id:

src/axolotl/utils/callbacks/__init__.py

@@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
 class SaveModelOnTrainEndCallback(TrainerCallback):
     """Callback to save model on train end"""

+    def on_step_end(  # pylint: disable=unused-argument
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if state.global_step >= state.max_steps:
+            control.should_save = True
+
     def on_train_end(  # pylint: disable=unused-argument
         self, args, state, control, **kwargs
     ):

src/axolotl/utils/chat_templates.py

@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
         "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
         "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
     }

     if user_choice in templates:
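The Jinja string added above can be sanity-checked outside axolotl; a sketch (tokenizer name borrowed from the tests later in this compare):

```python
# Render the llama3 chat template with a plain HF tokenizer.
from transformers import AutoTokenizer

LLAMA3_TEMPLATE = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}"

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.chat_template = LLAMA3_TEMPLATE
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```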

src/axolotl/utils/collators.py

@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 if feature == "attention_mask":
                     if self.multipack_attn:
                         arrays = [
-                            (i + 1) * np.array(item[feature])
+                            (i + 1) * np.array(item)
                             for i, item in enumerate(features[feature])
-                            if feature in item
                         ]
                     else:
                         arrays = [(1) * np.array(item) for item in features[feature]]
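The fix iterates the mask lists directly; each packed sub-sequence's mask is scaled by `(i + 1)` so boundaries survive concatenation. A tiny standalone check (shapes assumed):

```python
import numpy as np

features = {"attention_mask": [[1, 1, 1], [1, 1]]}
arrays = [
    (i + 1) * np.array(item)
    for i, item in enumerate(features["attention_mask"])
]
print(arrays)  # [array([1, 1, 1]), array([2, 2])]
```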

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
     inst = "inst"  # pylint: disable=invalid-name
     gemma = "gemma"  # pylint: disable=invalid-name
     cohere = "cohere"  # pylint: disable=invalid-name
+    llama3 = "llama3"  # pylint: disable=invalid-name


 class LoftQConfig(BaseModel):
@@ -516,6 +517,9 @@ class AxolotlInputConfig(
     sequence_len: int = Field(default=512)
     min_sample_len: Optional[int] = None
+    max_prompt_len: int = Field(
+        default=512, metadata={"help": "maximum prompt length for RL training"}
+    )
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -557,6 +561,8 @@ class AxolotlInputConfig(
     torch_compile: Optional[bool] = None
     torch_compile_backend: Optional[str] = None

+    custom_trainer_cls: Optional[str] = None
+
     max_steps: Optional[int] = None
     warmup_steps: Optional[int] = None
     warmup_ratio: Optional[float] = None

src/axolotl/utils/models.py

@@ -1,4 +1,5 @@
 """Module for models and model loading"""
+
 # pylint: disable=too-many-lines

 import logging
@@ -504,6 +505,9 @@ def load_model(
         bnb_config = {
             "load_in_8bit": True,
         }
+        # Exclude mamba blocks from int8 quantization for jamba
+        if cfg.model_config_type == "jamba":
+            bnb_config["llm_int8_skip_modules"] = ["mamba"]
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
         )
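Standalone, the resulting quantization config is equivalent to this sketch; `llm_int8_skip_modules` is the bitsandbytes option used above:

```python
from transformers import BitsAndBytesConfig

# 8-bit quantization for jamba, keeping mamba blocks in full precision
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mamba"],
)
```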

tests/prompt_strategies/test_sharegpt.py

@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
     GlaiveShareGPTPromptTokenizingStrategy,
     SimpleShareGPTPromptTokenizingStrategy,
     register_chatml_template,
+    register_llama3_template,
 )
 from axolotl.prompters import ShareGPTPrompterV2

 register_chatml_template()
+register_llama3_template()


 @pytest.fixture(name="sharegpt_dataset")
@@ -115,7 +117,53 @@ def fixture_tokenizer():
     return tokenizer


-class TestSharegpt:
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
+class TestSharegptLlama3:
+    """Test class for ShareGPT style datasets with llama-3 prompts"""
+
+    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="llama3",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            llama3_tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 9125, 128007,  # system header
+            271, 31724, 128009,  # sys prompt, eot
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
+class TestSharegptChatML:
     """
     Test class for sharegpt prompter
     """