more dpo fixes for dataset loading and docs (#1185) [skip ci]

* more dpo fixes for dataset loading and docs

* preprocess dpo datasets
This commit is contained in:
Wing Lian
2024-01-24 14:23:55 -05:00
committed by GitHub
parent d85d4942cf
commit 5bce45f800
4 changed files with 73 additions and 4 deletions

View File

@@ -34,6 +34,16 @@ datasets:
rl: ipo rl: ipo
``` ```
#### Using local dataset files
```yaml
datasets:
- ds_type: json
data_files:
- orca_rlhf.jsonl
split: train
type: chatml.intel
```
#### Trl autounwrap for peft #### Trl autounwrap for peft
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config. Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.

View File

@@ -13,6 +13,7 @@ from axolotl.cli import (
check_user_token, check_user_token,
load_cfg, load_cfg,
load_datasets, load_datasets,
load_rl_datasets,
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import PreprocessCliArgs from axolotl.common.cli import PreprocessCliArgs
@@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
LOG.warning(msg) LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cfg.rl:
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info( LOG.info(
Fore.GREEN Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"

View File

@@ -996,6 +996,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["lr_scheduler_kwargs"] = ( training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
if self.cfg.remove_unused_columns is not None:
training_args_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[ training_args_kwargs[
@@ -1013,7 +1019,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
training_args = TrainingArguments( training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps, max_steps=self.cfg.max_steps or total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate, learning_rate=self.cfg.learning_rate,
save_strategy="steps", save_strategy="steps",

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import yaml
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
@@ -853,6 +854,41 @@ def encode_packed_pretraining(
return chunked_data return chunked_data
def _get_path(ds_hash, cfg):
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
)
return prepared_ds_path
def _load_preprocessed_ds(cfg, sub_cfg):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
dataset = None
if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
return dataset
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
if cfg.is_preprocess and is_main_process():
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset.save_to_disk(str(prepared_ds_path))
def load_prepare_dpo_datasets(cfg): def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
@@ -889,12 +925,25 @@ def load_prepare_dpo_datasets(cfg):
return concatenate_datasets(split_datasets) return concatenate_datasets(split_datasets)
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset = load_split(cfg.datasets, cfg) train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
train_dataset = load_split(cfg.datasets, cfg)
eval_dataset = None eval_dataset = None
if cfg.test_datasets: if cfg.test_datasets:
eval_dataset = load_split(cfg.test_datasets, cfg) if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset: if not eval_dataset:
eval_dataset = None eval_dataset = None
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
return train_dataset, eval_dataset return train_dataset, eval_dataset