Compare commits
9 Commits
olmo-no-po
...
custom-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9a1f288cf | ||
|
|
1e1921b794 | ||
|
|
1634ac82e0 | ||
|
|
02982733ec | ||
|
|
5d97e65f95 | ||
|
|
2147cf6837 | ||
|
|
50421c8b1d | ||
|
|
b32c08f8cc | ||
|
|
fff06af8d0 |
37
README.md
37
README.md
@@ -34,6 +34,7 @@ Features:
|
||||
- [Mac](#mac)
|
||||
- [Google Colab](#google-colab)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
|
||||
- [Dataset](#dataset)
|
||||
- [Config](#config)
|
||||
- [Train](#train)
|
||||
@@ -292,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||
```
|
||||
|
||||
#### Launching on public clouds via dstack
|
||||
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
|
||||
|
||||
Write a job description in YAML as below:
|
||||
|
||||
```yaml
|
||||
# dstack.yaml
|
||||
type: task
|
||||
|
||||
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1
|
||||
|
||||
env:
|
||||
- HUGGING_FACE_HUB_TOKEN
|
||||
- WANDB_API_KEY
|
||||
|
||||
commands:
|
||||
- accelerate launch -m axolotl.cli.train config.yaml
|
||||
|
||||
ports:
|
||||
- 6006
|
||||
|
||||
resources:
|
||||
gpu:
|
||||
memory: 24GB..
|
||||
count: 2
|
||||
```
|
||||
|
||||
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
|
||||
|
||||
```bash
|
||||
pip install dstack
|
||||
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
|
||||
```
|
||||
|
||||
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
|
||||
|
||||
### Dataset
|
||||
|
||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||
|
||||
@@ -28,7 +28,7 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
parsed_cfg.flash_attention = False
|
||||
parsed_cfg.deepspeed = None
|
||||
parsed_cfg.fsdp = None
|
||||
parsed_cfg.fsdp_config = None
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
if parsed_cfg.chat_template == "chatml":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
elif parsed_cfg.chat_template == "llama3":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
|
||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
|
||||
@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.custom_trainer_cls:
|
||||
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
|
||||
return importlib.import_module(_module, _cls)
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -1526,6 +1529,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
training_args = training_args_cls(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
|
||||
@@ -123,6 +123,17 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||
if self.system_message:
|
||||
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||
else:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.GEMMA:
|
||||
if self.system_message:
|
||||
raise ValueError("Gemma chat template does not support system messages")
|
||||
|
||||
133
src/axolotl/prompt_strategies/dpo/llama3.py
Normal file
133
src/axolotl/prompt_strategies/dpo/llama3.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
DPO strategies for llama-3 chat template
|
||||
"""
|
||||
|
||||
|
||||
def argilla(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for argilla/dpo-mix-7k conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
For Intel Orca DPO Pairs
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def prompt_pairs(
|
||||
cfg, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for ultrafeedback binarized conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
|
||||
name="chatml",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
@@ -32,13 +32,29 @@ def register_chatml_template(system_message=None):
|
||||
name="chatml_glaive",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def register_llama3_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama3",
|
||||
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||
system_message=system_message,
|
||||
roles=("user", "assistant"),
|
||||
sep_style=SeparatorStyle.LLAMA3,
|
||||
sep="",
|
||||
stop_str="<|eot_id|>",
|
||||
stop_token_ids=[128001, 128009],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_loader(
|
||||
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||
|
||||
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
|
||||
"chatml": "<|im_start|>{ROLE}",
|
||||
"zephyr": "<|{ROLE}|>",
|
||||
"vicuna_v1.1": "{ROLE}",
|
||||
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_step_end( # pylint: disable=unused-argument
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_save = True
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
):
|
||||
|
||||
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
|
||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
|
||||
}
|
||||
|
||||
if user_choice in templates:
|
||||
|
||||
@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if feature == "attention_mask":
|
||||
if self.multipack_attn:
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
(i + 1) * np.array(item)
|
||||
for i, item in enumerate(features[feature])
|
||||
if feature in item
|
||||
]
|
||||
else:
|
||||
arrays = [(1) * np.array(item) for item in features[feature]]
|
||||
|
||||
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
|
||||
inst = "inst" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
@@ -516,6 +517,9 @@ class AxolotlInputConfig(
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
max_prompt_len: int = Field(
|
||||
default=512, metadata={"help": "maximum prompt length for RL training"}
|
||||
)
|
||||
sample_packing: Optional[bool] = None
|
||||
eval_sample_packing: Optional[bool] = None
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
@@ -557,6 +561,8 @@ class AxolotlInputConfig(
|
||||
torch_compile: Optional[bool] = None
|
||||
torch_compile_backend: Optional[str] = None
|
||||
|
||||
custom_trainer_cls: Optional[str] = None
|
||||
|
||||
max_steps: Optional[int] = None
|
||||
warmup_steps: Optional[int] = None
|
||||
warmup_ratio: Optional[float] = None
|
||||
|
||||
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
SimpleShareGPTPromptTokenizingStrategy,
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
register_chatml_template()
|
||||
register_llama3_template()
|
||||
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset")
|
||||
@@ -115,7 +117,53 @@ def fixture_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestSharegpt:
|
||||
@pytest.fixture(name="llama3_tokenizer")
|
||||
def fixture_llama3_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||
tokenizer.eos_token = "<|eot_id|>"
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestSharegptLlama3:
|
||||
"""Test class for ShareGPT style datasets with llama-3 prompts"""
|
||||
|
||||
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="llama3",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
llama3_tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 9125, 128007, # system header
|
||||
271, 31724, 128009, # sys prompt, eot
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class TestSharegptChatML:
|
||||
"""
|
||||
Test class for sharegpt prompter
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user