more dpo fixes for dataset loading and docs (#1185) [skip ci]
* more dpo fixes for dataset loading and docs * preprocess dpo datasets
This commit is contained in:
10
docs/rlhf.md
10
docs/rlhf.md
@@ -34,6 +34,16 @@ datasets:
|
|||||||
rl: ipo
|
rl: ipo
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using local dataset files
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- ds_type: json
|
||||||
|
data_files:
|
||||||
|
- orca_rlhf.jsonl
|
||||||
|
split: train
|
||||||
|
type: chatml.intel
|
||||||
|
```
|
||||||
|
|
||||||
#### Trl autounwrap for peft
|
#### Trl autounwrap for peft
|
||||||
|
|
||||||
TRL supports auto-unwrapping PEFT models, so a separate reference model does not need to be loaded, which reduces the VRAM required. This is on by default. To turn it off, pass the following config.
|
TRL supports auto-unwrapping PEFT models, so a separate reference model does not need to be loaded, which reduces the VRAM required. This is on by default. To turn it off, pass the following config.
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from axolotl.cli import (
|
|||||||
check_user_token,
|
check_user_token,
|
||||||
load_cfg,
|
load_cfg,
|
||||||
load_datasets,
|
load_datasets,
|
||||||
|
load_rl_datasets,
|
||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
@@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
if parsed_cfg.rl:
|
||||||
|
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
else:
|
||||||
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
Fore.GREEN
|
Fore.GREEN
|
||||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||||
|
|||||||
@@ -996,6 +996,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
|
if self.cfg.remove_unused_columns is not None:
|
||||||
|
training_args_kwargs[
|
||||||
|
"remove_unused_columns"
|
||||||
|
] = self.cfg.remove_unused_columns
|
||||||
|
else:
|
||||||
|
training_args_kwargs["remove_unused_columns"] = False
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs[
|
||||||
@@ -1013,7 +1019,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
remove_unused_columns=False,
|
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
learning_rate=self.cfg.learning_rate,
|
learning_rate=self.cfg.learning_rate,
|
||||||
save_strategy="steps",
|
save_strategy="steps",
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import yaml
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
@@ -853,6 +854,41 @@ def encode_packed_pretraining(
|
|||||||
return chunked_data
|
return chunked_data
|
||||||
|
|
||||||
|
|
||||||
|
def _get_path(ds_hash, cfg):
    """Return the on-disk directory for the prepared dataset named ``ds_hash``.

    The directory lives under ``cfg.dataset_prepared_path`` when that value is
    set (truthy); otherwise it lives under ``DEFAULT_DATASET_PREPARED_PATH``.
    """
    # `or` picks the configured directory when truthy, else the default.
    base_dir = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
    return Path(base_dir) / ds_hash
|
||||||
|
|
||||||
|
|
||||||
|
def _load_preprocessed_ds(cfg, sub_cfg):
    """Load a previously prepared dataset for ``sub_cfg`` from disk, if usable.

    The cache directory is keyed by ``md5`` of the YAML dump of ``sub_cfg``.
    Returns the dataset read by ``load_from_disk``, or ``None`` when any of
    the following holds:

    * ``cfg.dataset_prepared_path`` is not set,
    * the cache directory contains no entries,
    * ``cfg.is_preprocess`` is truthy (a preprocess run does not read the
      cache).
    """
    cache_key = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    cache_dir = _get_path(cache_key, cfg)

    # Guard clauses are checked in the same order as the original `and` chain.
    if not cfg.dataset_prepared_path:
        return None
    if not any(cache_dir.glob("*")):
        return None
    if cfg.is_preprocess:
        return None

    LOG.info(f"Loading prepared dataset from disk at {cache_dir}...")
    return load_from_disk(str(cache_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
    """Persist ``dataset`` to the prepared-dataset directory for ``sub_cfg``.

    The target directory is keyed by ``md5`` of the YAML dump of ``sub_cfg``
    and resolved through ``_get_path``, the same scheme
    ``_load_preprocessed_ds`` uses to locate the data.

    Nothing is written unless ``cfg.is_preprocess`` is truthy and
    ``is_main_process()`` returns true.
    """
    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    prepared_ds_path = _get_path(ds_hash, cfg)

    if cfg.is_preprocess and is_main_process():
        # The original message said "Loading ... from disk" (copied from the
        # loader); this code path saves, so the log now says so.
        LOG.info(f"Saving prepared dataset to disk at {prepared_ds_path}...")
        dataset.save_to_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_dpo_datasets(cfg):
|
def load_prepare_dpo_datasets(cfg):
|
||||||
def load_split(dataset_cfgs, _cfg):
|
def load_split(dataset_cfgs, _cfg):
|
||||||
split_datasets: List[Any] = []
|
split_datasets: List[Any] = []
|
||||||
@@ -889,12 +925,25 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
return concatenate_datasets(split_datasets)
|
return concatenate_datasets(split_datasets)
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset = load_split(cfg.datasets, cfg)
|
train_is_preprocessed = False
|
||||||
|
eval_is_preprocessed = False
|
||||||
|
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
||||||
|
train_is_preprocessed = True
|
||||||
|
else:
|
||||||
|
train_dataset = load_split(cfg.datasets, cfg)
|
||||||
|
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
if cfg.test_datasets:
|
if cfg.test_datasets:
|
||||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
||||||
|
eval_is_preprocessed = True
|
||||||
|
else:
|
||||||
|
eval_dataset = load_split(cfg.test_datasets, cfg)
|
||||||
if not eval_dataset:
|
if not eval_dataset:
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
|
if not train_is_preprocessed:
|
||||||
|
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
||||||
|
if eval_dataset and not eval_is_preprocessed:
|
||||||
|
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|||||||
Reference in New Issue
Block a user