Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
Chirag Jain
2024-10-13 16:27:10 +05:30
committed by GitHub
14 changed files with 249 additions and 18 deletions

View File

@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript - typescript
type: ... # unimplemented custom format type: ... # unimplemented custom format
# fastchat conversation # fastchat conversation (deprecation soon, use chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ... - path: ...
type: sharegpt type: sharegpt

View File

@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -314,6 +315,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it # mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it # Comet configuration if you're using it
@@ -362,7 +364,7 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf] eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -1,11 +1,11 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.13.0 peft==0.13.2
transformers==4.45.1 transformers==4.45.2
tokenizers>=0.19.1 tokenizers>=0.20.1
bitsandbytes==0.44.0 bitsandbytes==0.44.1
accelerate==0.34.2 accelerate==1.0.0
datasets==2.21.0 datasets==3.0.1
deepspeed==0.14.4 deepspeed==0.14.4
pydantic==2.6.3 pydantic==2.6.3
addict addict
@@ -52,3 +52,5 @@ lm_eval==0.4.4
langdetect==1.0.9 langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.5.0

60
scripts/chat_datasets.py Normal file
View File

@@ -0,0 +1,60 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()

View File

@@ -30,6 +30,7 @@ def parse_requirements():
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# don't install xformers on MacOS # don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
@@ -53,7 +54,8 @@ def parse_requirements():
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27") _install_requires.append("xformers>=0.0.27")
if (major, minor) >= (2, 3): elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1") _install_requires.append("xformers>=0.0.26.post1")
@@ -61,9 +63,11 @@ def parse_requirements():
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27") _install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2): elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1") _install_requires.append("xformers>=0.0.25.post1")
else: else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1") _install_requires.append("xformers>=0.0.23.post1")

View File

@@ -1445,9 +1445,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to.append("comet_ml") report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = ( if self.cfg.use_wandb:
self.cfg.wandb_name if self.cfg.use_wandb else None training_arguments_kwargs["run_name"] = self.cfg.wandb_name
) elif self.cfg.use_mlflow:
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
) )

View File

@@ -61,6 +61,9 @@ def build_loader(
default_conversation: Optional[str] = None, default_conversation: Optional[str] = None,
): ):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
)
conversation = ( conversation = (
ds_cfg["conversation"] ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg if ds_cfg and "conversation" in ds_cfg

File diff suppressed because one or more lines are too long

View File

@@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
@@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
) )
for example in examples for example in examples
] ]
images = [example["images"] for example in examples] images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0: if max_images > 0:
images = [img_batch[:max_images] for img_batch in images] images = [img_batch[:max_images] for img_batch in images]

View File

@@ -43,7 +43,9 @@ class ChatTemplate(str, Enum):
alpaca = "alpaca" # pylint: disable=invalid-name alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name
@@ -53,6 +55,7 @@ class ChatTemplate(str, Enum):
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
@@ -164,6 +167,7 @@ class SFTDataset(BaseModel):
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -206,6 +210,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None data_files: Optional[List[str]] = None
revision: Optional[str] = None
class UserDefinedKTOType(BaseModel): class UserDefinedKTOType(BaseModel):
@@ -227,6 +232,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
@@ -478,6 +484,7 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[bool] = None use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None mlflow_experiment_name: Optional[str] = None
mlflow_run_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None hf_mlflow_log_artifacts: Optional[bool] = None

View File

@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
ds = load_dataset( # pylint: disable=invalid-name ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"], ds_cfg["path"],
split=ds_cfg["split"], split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
) )
split_datasets.insert(i, ds) split_datasets.insert(i, ds)

View File

@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name, name=config_dataset.name,
streaming=True, streaming=True,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
streaming=False, streaming=False,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs, **load_ds_kwargs,
) )
elif ds_from_cloud and remote_file_system: elif ds_from_cloud and remote_file_system:
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=config_dataset.data_files, filename=config_dataset.data_files,
revision=config_dataset.revision,
) )
elif isinstance(config_dataset.data_files, list): elif isinstance(config_dataset.data_files, list):
fp = [] fp = []
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=file, filename=file,
revision=config_dataset.revision,
) )
) )
else: else:
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset, config_dataset=config_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
cfg=cfg, cfg=cfg,
dataset=ds,
d_base_type=d_base_type, d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style, d_prompt_style=d_prompt_style,
processor=processor, processor=processor,
) )

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch import torch
import torch.cuda import torch.cuda
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import set_caching_enabled from datasets import disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
@contextmanager @contextmanager
def disable_datasets_caching(): def disable_datasets_caching():
try: try:
set_caching_enabled(False) disable_caching()
yield yield
finally: finally:
set_caching_enabled(True) enable_caching()
def add_position_ids(sample): def add_position_ids(sample):

View File

@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()