Add Support for revision Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912)
* Add support for `revision` dataset parameter * only use revision on hf hub backed datasets * use revision tied to head * set download to use revision * feat: add config to model validator class * feat: add revision config to RL and tests for it --------- Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -90,6 +90,7 @@ datasets:
|
|||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
|
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ class SFTDataset(BaseModel):
|
|||||||
drop_system_message: Optional[bool] = None
|
drop_system_message: Optional[bool] = None
|
||||||
|
|
||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
@@ -146,6 +147,7 @@ class DPODataset(BaseModel):
|
|||||||
split: Optional[str] = None
|
split: Optional[str] = None
|
||||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedKTOType(BaseModel):
|
class UserDefinedKTOType(BaseModel):
|
||||||
@@ -167,6 +169,7 @@ class KTODataset(BaseModel):
|
|||||||
type: Optional[Union[UserDefinedKTOType, str]] = None
|
type: Optional[Union[UserDefinedKTOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class RLType(str, Enum):
|
class RLType(str, Enum):
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
ds = load_dataset( # pylint: disable=invalid-name
|
ds = load_dataset( # pylint: disable=invalid-name
|
||||||
ds_cfg["path"],
|
ds_cfg["path"],
|
||||||
split=ds_cfg["split"],
|
split=ds_cfg["split"],
|
||||||
|
revision=ds_cfg.get("revision", None),
|
||||||
)
|
)
|
||||||
split_datasets.insert(i, ds)
|
split_datasets.insert(i, ds)
|
||||||
|
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
|
revision=config_dataset.revision,
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
elif ds_from_cloud and remote_file_system:
|
||||||
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=config_dataset.data_files,
|
filename=config_dataset.data_files,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
elif isinstance(config_dataset.data_files, list):
|
elif isinstance(config_dataset.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
repo_id=config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
|
revision=config_dataset.revision,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
config_dataset=config_dataset,
|
config_dataset=config_dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
dataset=ds,
|
|
||||||
d_base_type=d_base_type,
|
d_base_type=d_base_type,
|
||||||
|
dataset=ds,
|
||||||
d_prompt_style=d_prompt_style,
|
d_prompt_style=d_prompt_style,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||||
|
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_dpo(self):
|
||||||
|
"""Verify that processing dpo data from the hub works"""
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"rl": "dpo",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
"type": "chat_template.default",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "chosen",
|
||||||
|
"field_rejected": "rejected",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"system": ["system"],
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||||
|
|
||||||
|
assert len(train_dataset) == 1800
|
||||||
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_revision(self):
|
||||||
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
"revision": "d05c1cb",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
def test_load_hub_with_revision_with_dpo(self):
|
||||||
|
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"rl": "dpo",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
"type": "chat_template.default",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"revision": "ea82cff",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "chosen",
|
||||||
|
"field_rejected": "rejected",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"system": ["system"],
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||||
|
|
||||||
|
assert len(train_dataset) == 1800
|
||||||
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
|
def test_load_local_hub_with_revision(self):
|
||||||
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
||||||
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=tmp_ds_path,
|
||||||
|
revision="d05c1cb",
|
||||||
|
)
|
||||||
|
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"ds_type": "parquet",
|
||||||
|
"type": "alpaca",
|
||||||
|
"data_files": [
|
||||||
|
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
||||||
|
],
|
||||||
|
"revision": "d05c1cb",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user