Add Support for revision Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912)

* Add support for `revision` dataset parameter

* only use revision on hf hub backed datasets

* use revision tied to head

* set download to use revision

* feat: add config to model validator class

* feat: add revision config to RL and tests for it

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Thomas Cleberg
2024-10-11 12:32:50 -05:00
committed by GitHub
parent 2fbc6b0c64
commit e73b8dff8d
5 changed files with 148 additions and 1 deletions

View File

@@ -125,6 +125,7 @@ class SFTDataset(BaseModel):
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class UserDefinedDPOType(BaseModel):
@@ -146,6 +147,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
class UserDefinedKTOType(BaseModel):
@@ -167,6 +169,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class RLType(str, Enum):

View File

@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)

View File

@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style,
processor=processor,
)