Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- typescript
|
- typescript
|
||||||
type: ... # unimplemented custom format
|
type: ... # unimplemented custom format
|
||||||
|
|
||||||
# fastchat conversation
|
# fastchat conversation (deprecation soon, use chat_template)
|
||||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
- path: ...
|
- path: ...
|
||||||
type: sharegpt
|
type: sharegpt
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ datasets:
|
|||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
|
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
@@ -314,6 +315,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
|||||||
# mlflow configuration if you're using it
|
# mlflow configuration if you're using it
|
||||||
mlflow_tracking_uri: # URI to mlflow
|
mlflow_tracking_uri: # URI to mlflow
|
||||||
mlflow_experiment_name: # Your experiment name
|
mlflow_experiment_name: # Your experiment name
|
||||||
|
mlflow_run_name: # Your run name
|
||||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||||
|
|
||||||
# Comet configuration if you're using it
|
# Comet configuration if you're using it
|
||||||
@@ -362,7 +364,7 @@ max_steps:
|
|||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
|
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||||
|
|
||||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.0
|
peft==0.13.2
|
||||||
transformers==4.45.1
|
transformers==4.45.2
|
||||||
tokenizers>=0.19.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.0
|
bitsandbytes==0.44.1
|
||||||
accelerate==0.34.2
|
accelerate==1.0.0
|
||||||
datasets==2.21.0
|
datasets==3.0.1
|
||||||
deepspeed==0.14.4
|
deepspeed==0.14.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
@@ -52,3 +52,5 @@ lm_eval==0.4.4
|
|||||||
langdetect==1.0.9
|
langdetect==1.0.9
|
||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
|
torchao==0.5.0
|
||||||
|
|||||||
60
scripts/chat_datasets.py
Normal file
60
scripts/chat_datasets.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
helper script to parse chat datasets into a usable yaml
|
||||||
|
"""
|
||||||
|
import click
|
||||||
|
import yaml
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.argument("dataset", type=str)
|
||||||
|
@click.option("--split", type=str, default="train")
|
||||||
|
def parse_dataset(dataset=None, split="train"):
|
||||||
|
ds_cfg = {}
|
||||||
|
ds_cfg["path"] = dataset
|
||||||
|
ds_cfg["split"] = split
|
||||||
|
ds_cfg["type"] = "chat_template"
|
||||||
|
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
|
||||||
|
|
||||||
|
dataset = load_dataset(dataset, split=split)
|
||||||
|
features = dataset.features
|
||||||
|
feature_keys = features.keys()
|
||||||
|
field_messages = None
|
||||||
|
for key in ["conversation", "conversations", "messages"]:
|
||||||
|
if key in feature_keys:
|
||||||
|
field_messages = key
|
||||||
|
break
|
||||||
|
if not field_messages:
|
||||||
|
raise ValueError(
|
||||||
|
f'No conversation field found in dataset: {", ".join(feature_keys)}'
|
||||||
|
)
|
||||||
|
ds_cfg["field_messages"] = field_messages
|
||||||
|
|
||||||
|
message_fields = features["conversations"][0].keys()
|
||||||
|
message_field_role = None
|
||||||
|
for key in ["from", "role"]:
|
||||||
|
if key in message_fields:
|
||||||
|
message_field_role = key
|
||||||
|
break
|
||||||
|
if not message_field_role:
|
||||||
|
raise ValueError(
|
||||||
|
f'No role field found in messages: {", ".join(message_fields)}'
|
||||||
|
)
|
||||||
|
ds_cfg["message_field_role"] = message_field_role
|
||||||
|
|
||||||
|
message_field_content = None
|
||||||
|
for key in ["content", "text", "value"]:
|
||||||
|
if key in message_fields:
|
||||||
|
message_field_content = key
|
||||||
|
break
|
||||||
|
if not message_field_content:
|
||||||
|
raise ValueError(
|
||||||
|
f'No content field found in messages: {", ".join(message_fields)}'
|
||||||
|
)
|
||||||
|
ds_cfg["message_field_content"] = message_field_content
|
||||||
|
|
||||||
|
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parse_dataset()
|
||||||
6
setup.py
6
setup.py
@@ -30,6 +30,7 @@ def parse_requirements():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
|
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# don't install xformers on MacOS
|
# don't install xformers on MacOS
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
@@ -53,7 +54,8 @@ def parse_requirements():
|
|||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.27")
|
_install_requires.append("xformers>=0.0.27")
|
||||||
if (major, minor) >= (2, 3):
|
elif (major, minor) >= (2, 3):
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.26.post1")
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
@@ -61,9 +63,11 @@ def parse_requirements():
|
|||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.27")
|
_install_requires.append("xformers>=0.0.27")
|
||||||
elif (major, minor) >= (2, 2):
|
elif (major, minor) >= (2, 2):
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
else:
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.23.post1")
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
|
|||||||
@@ -1445,9 +1445,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
report_to.append("comet_ml")
|
report_to.append("comet_ml")
|
||||||
|
|
||||||
training_arguments_kwargs["report_to"] = report_to
|
training_arguments_kwargs["report_to"] = report_to
|
||||||
training_arguments_kwargs["run_name"] = (
|
if self.cfg.use_wandb:
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
)
|
elif self.cfg.use_mlflow:
|
||||||
|
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
|
else:
|
||||||
|
training_arguments_kwargs["run_name"] = None
|
||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -61,6 +61,9 @@ def build_loader(
|
|||||||
default_conversation: Optional[str] = None,
|
default_conversation: Optional[str] = None,
|
||||||
):
|
):
|
||||||
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
LOG.warning(
|
||||||
|
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
|
||||||
|
)
|
||||||
conversation = (
|
conversation = (
|
||||||
ds_cfg["conversation"]
|
ds_cfg["conversation"]
|
||||||
if ds_cfg and "conversation" in ds_cfg
|
if ds_cfg and "conversation" in ds_cfg
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||||
from transformers.data.data_collator import DataCollatorMixin
|
from transformers.data.data_collator import DataCollatorMixin
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
@@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
)
|
)
|
||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
images = [example["images"] for example in examples]
|
images = [
|
||||||
|
Image.open(example["images"])
|
||||||
|
if isinstance(example["images"], str)
|
||||||
|
else example["images"]
|
||||||
|
for example in examples
|
||||||
|
]
|
||||||
|
|
||||||
if max_images > 0:
|
if max_images > 0:
|
||||||
images = [img_batch[:max_images] for img_batch in images]
|
images = [img_batch[:max_images] for img_batch in images]
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ class ChatTemplate(str, Enum):
|
|||||||
|
|
||||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||||
chatml = "chatml" # pylint: disable=invalid-name
|
chatml = "chatml" # pylint: disable=invalid-name
|
||||||
inst = "inst" # pylint: disable=invalid-name
|
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||||
|
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||||
|
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
@@ -53,6 +55,7 @@ class ChatTemplate(str, Enum):
|
|||||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||||
jamba = "jamba" # pylint: disable=invalid-name
|
jamba = "jamba" # pylint: disable=invalid-name
|
||||||
jinja = "jinja" # pylint: disable=invalid-name
|
jinja = "jinja" # pylint: disable=invalid-name
|
||||||
|
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
@@ -164,6 +167,7 @@ class SFTDataset(BaseModel):
|
|||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
drop_system_message: Optional[bool] = None
|
drop_system_message: Optional[bool] = None
|
||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -206,6 +210,7 @@ class DPODataset(BaseModel):
|
|||||||
split: Optional[str] = None
|
split: Optional[str] = None
|
||||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedKTOType(BaseModel):
|
class UserDefinedKTOType(BaseModel):
|
||||||
@@ -227,6 +232,7 @@ class KTODataset(BaseModel):
|
|||||||
type: Optional[Union[UserDefinedKTOType, str]] = None
|
type: Optional[Union[UserDefinedKTOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
@@ -478,6 +484,7 @@ class MLFlowConfig(BaseModel):
|
|||||||
use_mlflow: Optional[bool] = None
|
use_mlflow: Optional[bool] = None
|
||||||
mlflow_tracking_uri: Optional[str] = None
|
mlflow_tracking_uri: Optional[str] = None
|
||||||
mlflow_experiment_name: Optional[str] = None
|
mlflow_experiment_name: Optional[str] = None
|
||||||
|
mlflow_run_name: Optional[str] = None
|
||||||
hf_mlflow_log_artifacts: Optional[bool] = None
|
hf_mlflow_log_artifacts: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
ds = load_dataset( # pylint: disable=invalid-name
|
ds = load_dataset( # pylint: disable=invalid-name
|
||||||
ds_cfg["path"],
|
ds_cfg["path"],
|
||||||
split=ds_cfg["split"],
|
split=ds_cfg["split"],
|
||||||
|
revision=ds_cfg.get("revision", None),
|
||||||
)
|
)
|
||||||
split_datasets.insert(i, ds)
|
split_datasets.insert(i, ds)
|
||||||
|
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
elif ds_from_cloud and remote_file_system:
|
||||||
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=config_dataset.data_files,
|
filename=config_dataset.data_files,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
elif isinstance(config_dataset.data_files, list):
|
elif isinstance(config_dataset.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
config_dataset=config_dataset,
|
config_dataset=config_dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
dataset=ds,
|
|
||||||
d_base_type=d_base_type,
|
d_base_type=d_base_type,
|
||||||
|
dataset=ds,
|
||||||
d_prompt_style=d_prompt_style,
|
d_prompt_style=d_prompt_style,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import set_caching_enabled
|
from datasets import disable_caching, enable_caching
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def disable_datasets_caching():
|
def disable_datasets_caching():
|
||||||
try:
|
try:
|
||||||
set_caching_enabled(False)
|
disable_caching()
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
set_caching_enabled(True)
|
enable_caching()
|
||||||
|
|
||||||
|
|
||||||
def add_position_ids(sample):
|
def add_position_ids(sample):
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||||
|
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_dpo(self):
|
||||||
|
"""Verify that processing dpo data from the hub works"""
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"rl": "dpo",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
"type": "chat_template.default",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "chosen",
|
||||||
|
"field_rejected": "rejected",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"system": ["system"],
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||||
|
|
||||||
|
assert len(train_dataset) == 1800
|
||||||
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_revision(self):
|
||||||
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
"revision": "d05c1cb",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_revision_with_dpo(self):
|
||||||
|
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"rl": "dpo",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
"type": "chat_template.default",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"revision": "ea82cff",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "chosen",
|
||||||
|
"field_rejected": "rejected",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"system": ["system"],
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||||
|
|
||||||
|
assert len(train_dataset) == 1800
|
||||||
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
|
def test_load_local_hub_with_revision(self):
|
||||||
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
||||||
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=tmp_ds_path,
|
||||||
|
revision="d05c1cb",
|
||||||
|
)
|
||||||
|
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"ds_type": "parquet",
|
||||||
|
"type": "alpaca",
|
||||||
|
"data_files": [
|
||||||
|
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
||||||
|
],
|
||||||
|
"revision": "d05c1cb",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user