Phi-3 conversation format, example training script and perplexity metric (#1582)

* phi-3 support and perplexity metric

* phi-3 chat template

* metrics updates

* chore: lint

* fix assertion on Tensor

* fix tests since tokenization happens in the metric

* fix perplexity value of shorter passage

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Brian Fitzgerald
2024-06-04 15:11:56 -05:00
committed by GitHub
parent c996881ec2
commit cf64284a04
10 changed files with 243 additions and 26 deletions

View File

@@ -474,12 +474,16 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)
if split == "train" and cfg.val_set_size:
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)
if split == "train" and val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "train"
+ "|"
@@ -488,7 +492,7 @@ def load_prepare_datasets(
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "test"
+ "|"
@@ -498,9 +502,7 @@ def load_prepare_datasets(
test_fingerprint = md5(to_hash_test)
dataset = dataset.train_test_split(
test_size=int(cfg.val_set_size)
if cfg.val_set_size == int(cfg.val_set_size)
else cfg.val_set_size,
test_size=val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}
LOG.info(
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)
if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features