Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
Chirag Jain
2024-10-13 16:27:10 +05:30
committed by GitHub
14 changed files with 249 additions and 18 deletions

View File

@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# fastchat conversation (deprecation soon, use chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt

View File

@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -314,6 +315,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
@@ -362,7 +364,7 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -1,11 +1,11 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.0
transformers==4.45.1
tokenizers>=0.19.1
bitsandbytes==0.44.0
accelerate==0.34.2
datasets==2.21.0
peft==0.13.2
transformers==4.45.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.0
datasets==3.0.1
deepspeed==0.14.4
pydantic==2.6.3
addict
@@ -52,3 +52,5 @@ lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0

60
scripts/chat_datasets.py Normal file
View File

@@ -0,0 +1,60 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()

View File

@@ -30,6 +30,7 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -53,7 +54,8 @@ def parse_requirements():
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
if (major, minor) >= (2, 3):
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
@@ -61,9 +63,11 @@ def parse_requirements():
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")

View File

@@ -1445,9 +1445,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
)
if self.cfg.use_wandb:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)

View File

@@ -61,6 +61,9 @@ def build_loader(
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg

File diff suppressed because one or more lines are too long

View File

@@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
@@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
)
for example in examples
]
images = [example["images"] for example in examples]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]

View File

@@ -43,7 +43,9 @@ class ChatTemplate(str, Enum):
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
@@ -53,6 +55,7 @@ class ChatTemplate(str, Enum):
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
@@ -164,6 +167,7 @@ class SFTDataset(BaseModel):
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
@@ -206,6 +210,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
class UserDefinedKTOType(BaseModel):
@@ -227,6 +232,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class LoftQConfig(BaseModel):
@@ -478,6 +484,7 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
mlflow_run_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None

View File

@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)

View File

@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style,
processor=processor,
)

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import set_caching_enabled
from datasets import disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
disable_caching()
yield
finally:
set_caching_enabled(True)
enable_caching()
def add_position_ids(sample):

View File

@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
@@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()