Compare commits
13 Commits
sppo
...
custom-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9a1f288cf | ||
|
|
1e1921b794 | ||
|
|
1634ac82e0 | ||
|
|
02982733ec | ||
|
|
5d97e65f95 | ||
|
|
2147cf6837 | ||
|
|
50421c8b1d | ||
|
|
b32c08f8cc | ||
|
|
fff06af8d0 | ||
|
|
796a085b2f | ||
|
|
cb78a36374 | ||
|
|
8b9c15b17f | ||
|
|
9e1480e9ca |
37
README.md
37
README.md
@@ -34,6 +34,7 @@ Features:
|
|||||||
- [Mac](#mac)
|
- [Mac](#mac)
|
||||||
- [Google Colab](#google-colab)
|
- [Google Colab](#google-colab)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
|
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
@@ -292,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
|||||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Launching on public clouds via dstack
|
||||||
|
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
|
||||||
|
|
||||||
|
Write a job description in YAML as below:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# dstack.yaml
|
||||||
|
type: task
|
||||||
|
|
||||||
|
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1
|
||||||
|
|
||||||
|
env:
|
||||||
|
- HUGGING_FACE_HUB_TOKEN
|
||||||
|
- WANDB_API_KEY
|
||||||
|
|
||||||
|
commands:
|
||||||
|
- accelerate launch -m axolotl.cli.train config.yaml
|
||||||
|
|
||||||
|
ports:
|
||||||
|
- 6006
|
||||||
|
|
||||||
|
resources:
|
||||||
|
gpu:
|
||||||
|
memory: 24GB..
|
||||||
|
count: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install dstack
|
||||||
|
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
|
||||||
|
```
|
||||||
|
|
||||||
|
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
|
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
@@ -39,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
trl==0.8.5
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
|
deepspeed=None,
|
||||||
|
fsdp=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg.flash_attention = False
|
parsed_cfg.flash_attention = False
|
||||||
parsed_cfg.deepspeed = None
|
parsed_cfg.deepspeed = None
|
||||||
parsed_cfg.fsdp = None
|
parsed_cfg.fsdp = None
|
||||||
|
parsed_cfg.fsdp_config = None
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
if parsed_cfg.chat_template == "chatml":
|
||||||
LOG.info(
|
if parsed_cfg.default_system_message:
|
||||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
LOG.info(
|
||||||
)
|
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
register_chatml_template(parsed_cfg.default_system_message)
|
)
|
||||||
else:
|
register_chatml_template(parsed_cfg.default_system_message)
|
||||||
register_chatml_template()
|
else:
|
||||||
|
register_chatml_template()
|
||||||
|
elif parsed_cfg.chat_template == "llama3":
|
||||||
|
if parsed_cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(parsed_cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
|
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if cfg.rl: # and cfg.rl != "orpo":
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return AxolotlMambaTrainer
|
return AxolotlMambaTrainer
|
||||||
|
if self.cfg.custom_trainer_cls:
|
||||||
|
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
|
||||||
|
return importlib.import_module(_module, _cls)
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -1526,9 +1529,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = ORPOConfig
|
training_args_cls = ORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
training_args_cls = DPOConfig
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
training_args = training_args_cls(
|
training_args = training_args_cls(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
@@ -1555,8 +1558,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
elif self.cfg.rl == "kto_pair":
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
elif self.cfg.rl == "sppo_hard":
|
|
||||||
dpo_trainer_kwargs["loss_type"] = "sppo_hard"
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1565,7 +1566,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|||||||
@@ -123,6 +123,17 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||||
|
if self.system_message:
|
||||||
|
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||||
|
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||||
|
else:
|
||||||
|
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||||
|
return
|
||||||
if self.sep_style == SeparatorStyle.GEMMA:
|
if self.sep_style == SeparatorStyle.GEMMA:
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
raise ValueError("Gemma chat template does not support system messages")
|
raise ValueError("Gemma chat template does not support system messages")
|
||||||
|
|||||||
133
src/axolotl/prompt_strategies/dpo/llama3.py
Normal file
133
src/axolotl/prompt_strategies/dpo/llama3.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
DPO strategies for llama-3 chat template
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def argilla(
|
||||||
|
cfg,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def argilla_chat(
|
||||||
|
cfg,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
for argilla/dpo-mix-7k conversations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def icr(
|
||||||
|
cfg,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
chatml transforms for datasets with system, input, chosen, rejected
|
||||||
|
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
For Intel Orca DPO Pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_pairs(
|
||||||
|
cfg, **kwargs
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
"""
|
||||||
|
for ultrafeedback binarized conversations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if "system" in sample and sample["system"]:
|
||||||
|
sample["prompt"] = (
|
||||||
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample[
|
||||||
|
"prompt"
|
||||||
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||||
|
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
"""
|
|
||||||
DPO strategies for mistral instruct
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
def transform_fn(sample):
|
|
||||||
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
|
|
||||||
sample["chosen"] = f"{sample['chosen']}"
|
|
||||||
sample["rejected"] = f"{sample['rejected']}"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def argilla_chat(
|
|
||||||
cfg,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
for argilla/dpo-mix-7k conversations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def transform_fn(sample):
|
|
||||||
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
|
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
|
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
|
|||||||
name="chatml",
|
name="chatml",
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
@@ -32,83 +32,63 @@ def register_chatml_template(system_message=None):
|
|||||||
name="chatml_glaive",
|
name="chatml_glaive",
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def register_llama3_template(system_message=None):
|
||||||
conversation = (
|
system_message = system_message or "You are a helpful assistant."
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
register_conv_template(
|
||||||
)
|
Conversation(
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
name="llama3",
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
system_message=system_message,
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
roles=("user", "assistant"),
|
||||||
ShareGPTPrompterV2(
|
sep_style=SeparatorStyle.LLAMA3,
|
||||||
conversation=conversation,
|
sep="",
|
||||||
role_key_model=field_model,
|
stop_str="<|eot_id|>",
|
||||||
role_key_human=field_human,
|
stop_token_ids=[128001, 128009],
|
||||||
roles=roles,
|
)
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
|
||||||
)
|
|
||||||
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(
|
|
||||||
conversation=conversation,
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_guanaco(tokenizer, cfg):
|
def build_loader(
|
||||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||||
ShareGPTPrompterV2(),
|
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||||
tokenizer,
|
default_conversation: Optional[str] = None,
|
||||||
cfg.train_on_inputs,
|
):
|
||||||
cfg.sequence_len,
|
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
)
|
conversation = (
|
||||||
|
ds_cfg["conversation"]
|
||||||
|
if ds_cfg and "conversation" in ds_cfg
|
||||||
|
else default_conversation
|
||||||
|
)
|
||||||
|
field_human = (
|
||||||
|
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
|
)
|
||||||
|
field_model = (
|
||||||
|
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
|
)
|
||||||
|
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||||
|
strategy = tokenization_strategy_cls(
|
||||||
|
prompter_cls(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
roles=roles,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||||
|
strategy.strict = ds_cfg["strict"]
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
return _load
|
||||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"]
|
|
||||||
if ds_cfg and "conversation" in ds_cfg
|
|
||||||
else "chatml_glaive"
|
|
||||||
)
|
|
||||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(conversation=conversation),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
@@ -158,7 +138,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
|
SimpleShareGPTPromptTokenizingStrategy
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||||
"""
|
"""
|
||||||
@@ -209,3 +191,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
|
|||||||
conversation = merge_consecutive_messages(conversation)
|
conversation = merge_consecutive_messages(conversation)
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
|
||||||
|
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_ultrachat = build_loader(
|
||||||
|
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||||
|
)
|
||||||
|
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_glaive = build_loader(
|
||||||
|
GlaiveShareGPTPromptTokenizingStrategy,
|
||||||
|
ShareGPTPrompterV2,
|
||||||
|
default_conversation="chatml_glaive",
|
||||||
|
)
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
|
|||||||
"chatml": "<|im_start|>{ROLE}",
|
"chatml": "<|im_start|>{ROLE}",
|
||||||
"zephyr": "<|{ROLE}|>",
|
"zephyr": "<|{ROLE}|>",
|
||||||
"vicuna_v1.1": "{ROLE}",
|
"vicuna_v1.1": "{ROLE}",
|
||||||
|
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -348,7 +349,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
if (
|
||||||
|
role != "assistant"
|
||||||
|
): # back to back assistant calls may be okay for tool calls
|
||||||
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
|
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
|
|||||||
@@ -212,6 +212,10 @@ def train(
|
|||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
|
trainer.model.save_pretrained(
|
||||||
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
|||||||
@@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||||
"""Callback to save model on train end"""
|
"""Callback to save model on train end"""
|
||||||
|
|
||||||
|
def on_step_end( # pylint: disable=unused-argument
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# Save
|
||||||
|
if state.global_step >= state.max_steps:
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
def on_train_end( # pylint: disable=unused-argument
|
def on_train_end( # pylint: disable=unused-argument
|
||||||
self, args, state, control, **kwargs
|
self, args, state, control, **kwargs
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
|
|||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
|
||||||
}
|
}
|
||||||
|
|
||||||
if user_choice in templates:
|
if user_choice in templates:
|
||||||
|
|||||||
@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
if self.multipack_attn:
|
if self.multipack_attn:
|
||||||
arrays = [
|
arrays = [
|
||||||
(i + 1) * np.array(item[feature])
|
(i + 1) * np.array(item)
|
||||||
for i, item in enumerate(features[feature])
|
for i, item in enumerate(features[feature])
|
||||||
if feature in item
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
arrays = [(1) * np.array(item) for item in features[feature]]
|
arrays = [(1) * np.array(item) for item in features[feature]]
|
||||||
|
|||||||
@@ -133,7 +133,6 @@ class RLType(str, Enum):
|
|||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
@@ -144,6 +143,7 @@ class ChatTemplate(str, Enum):
|
|||||||
inst = "inst" # pylint: disable=invalid-name
|
inst = "inst" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
@@ -517,6 +517,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
min_sample_len: Optional[int] = None
|
min_sample_len: Optional[int] = None
|
||||||
|
max_prompt_len: int = Field(
|
||||||
|
default=512, metadata={"help": "maximum prompt length for RL training"}
|
||||||
|
)
|
||||||
sample_packing: Optional[bool] = None
|
sample_packing: Optional[bool] = None
|
||||||
eval_sample_packing: Optional[bool] = None
|
eval_sample_packing: Optional[bool] = None
|
||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
@@ -558,6 +561,8 @@ class AxolotlInputConfig(
|
|||||||
torch_compile: Optional[bool] = None
|
torch_compile: Optional[bool] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
|
|
||||||
|
custom_trainer_cls: Optional[str] = None
|
||||||
|
|
||||||
max_steps: Optional[int] = None
|
max_steps: Optional[int] = None
|
||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
warmup_ratio: Optional[float] = None
|
warmup_ratio: Optional[float] = None
|
||||||
@@ -575,7 +580,6 @@ class AxolotlInputConfig(
|
|||||||
neftune_noise_alpha: Optional[float] = None
|
neftune_noise_alpha: Optional[float] = None
|
||||||
|
|
||||||
orpo_alpha: Optional[float] = None
|
orpo_alpha: Optional[float] = None
|
||||||
dpo_beta: Optional[float] = None
|
|
||||||
|
|
||||||
max_memory: Optional[
|
max_memory: Optional[
|
||||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -504,6 +505,9 @@ def load_model(
|
|||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
}
|
}
|
||||||
|
# Exclude mamba blocks from int8 quantization for jamba
|
||||||
|
if cfg.model_config_type == "jamba":
|
||||||
|
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
@@ -789,11 +793,7 @@ def load_model(
|
|||||||
if not reference_model or cfg.lora_model_dir:
|
if not reference_model or cfg.lora_model_dir:
|
||||||
# if we're not loading the reference model, then we're loading the model for training
|
# if we're not loading the reference model, then we're loading the model for training
|
||||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||||
if (
|
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
|
||||||
cfg.adapter
|
|
||||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
|
|
||||||
and not cfg.merge_lora
|
|
||||||
):
|
|
||||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|||||||
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
|
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
|
|||||||
GlaiveShareGPTPromptTokenizingStrategy,
|
GlaiveShareGPTPromptTokenizingStrategy,
|
||||||
SimpleShareGPTPromptTokenizingStrategy,
|
SimpleShareGPTPromptTokenizingStrategy,
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import ShareGPTPrompterV2
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sharegpt_dataset")
|
@pytest.fixture(name="sharegpt_dataset")
|
||||||
@@ -115,7 +117,53 @@ def fixture_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestSharegpt:
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
def fixture_llama3_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
tokenizer.eos_token = "<|eot_id|>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharegptLlama3:
|
||||||
|
"""Test class for ShareGPT style datasets with llama-3 prompts"""
|
||||||
|
|
||||||
|
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
|
||||||
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation="llama3",
|
||||||
|
role_key_model=None,
|
||||||
|
role_key_human=None,
|
||||||
|
),
|
||||||
|
llama3_tokenizer,
|
||||||
|
False, # train_on_inputs
|
||||||
|
2048, # sequence_len
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
strategy, sharegpt_dataset, process_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
128000, # bos
|
||||||
|
128006, 9125, 128007, # system header
|
||||||
|
271, 31724, 128009, # sys prompt, eot
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 15339, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 15339, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharegptChatML:
|
||||||
"""
|
"""
|
||||||
Test class for sharegpt prompter
|
Test class for sharegpt prompter
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user