Reward model (#1879)
This commit is contained in:
63
examples/gemma2/reward-model.yaml
Normal file
63
examples/gemma2/reward-model.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: google/gemma-2-2b
|
||||
model_type: AutoModelForSequenceClassification
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
reward_model: true
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
remove_unused_columns: false
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -43,8 +43,10 @@ from trl import (
|
||||
KTOTrainer,
|
||||
ORPOConfig,
|
||||
ORPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
||||
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
@@ -301,6 +303,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||
"""
|
||||
Reward config for Reward training
|
||||
"""
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
@@ -398,12 +407,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
num_epochs=1,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
super().__init__(*_args, **kwargs)
|
||||
@@ -1039,6 +1046,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
@@ -1214,6 +1229,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -1553,6 +1570,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
@@ -1596,10 +1616,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
training_args_cls = (
|
||||
AxolotlTrainingArguments
|
||||
if not self.cfg.reward_model
|
||||
else AxolotlRewardConfig
|
||||
)
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
|
||||
@@ -1621,10 +1644,24 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if self.cfg.reward_model:
|
||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
if eval_data_collator := self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
):
|
||||
if not self.cfg.reward_model:
|
||||
trainer_kwargs["eval_data_collator"] = eval_data_collator
|
||||
if not self.cfg.reward_model:
|
||||
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
@@ -1632,16 +1669,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||
eval_data_collator=self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
@@ -1675,9 +1703,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
RewardDataCollatorWithPadding,
|
||||
]
|
||||
]
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
elif use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
|
||||
10
src/axolotl/prompt_strategies/bradley_terry/README.md
Normal file
10
src/axolotl/prompt_strategies/bradley_terry/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
### example yaml
|
||||
|
||||
```yaml
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
```
|
||||
35
src/axolotl/prompt_strategies/bradley_terry/__init__.py
Normal file
35
src/axolotl/prompt_strategies/bradley_terry/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies")
|
||||
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
# pylint: disable=duplicate-code
|
||||
try:
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(
|
||||
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
|
||||
)
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||
else:
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
return None
|
||||
88
src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Normal file
88
src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Bradley-Terry model with chat template prompt strategy.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplatePrompter,
|
||||
ChatTemplateStrategy,
|
||||
)
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
"""
|
||||
Bradley-Terry reward model pairwise chat template prompt strategy.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
"""
|
||||
|
||||
:param prompt: the actual row of data from the underlying dataset
|
||||
:return:
|
||||
"""
|
||||
|
||||
self.messages = "chosen_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
|
||||
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
|
||||
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
|
||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
self.messages = "rejected_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
|
||||
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
|
||||
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
|
||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
return {
|
||||
"input_ids_chosen": chosen_tokenized["input_ids"],
|
||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||
"labels_chosen": 1.0,
|
||||
"input_ids_rejected": rejected_tokenized["input_ids"],
|
||||
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
||||
"labels_rejected": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail", "train_detail"
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1
|
||||
if not cfg.reward_model
|
||||
else cfg.sequence_len,
|
||||
}
|
||||
|
||||
strategy_params = {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
}
|
||||
|
||||
strategy = BTChatTemplateStrategy(
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
27
src/axolotl/prompt_strategies/bradley_terry/llama3.py
Normal file
27
src/axolotl/prompt_strategies/bradley_terry/llama3.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
|
||||
"""
|
||||
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
prompt = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -403,6 +403,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
prompter_params = {
|
||||
|
||||
@@ -102,7 +102,8 @@ def train(
|
||||
model, peft_config = load_model(
|
||||
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||
)
|
||||
model.generation_config.do_sample = True
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
|
||||
@@ -551,6 +551,7 @@ class AxolotlInputConfig(
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
reward_model: Optional[bool] = None
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
@@ -856,6 +857,17 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def hint_reward_model_pad(cls, data):
|
||||
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using reward_model"
|
||||
)
|
||||
if data.get("pad_to_sequence_len") is None:
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gas_bsz(cls, data):
|
||||
|
||||
@@ -19,6 +19,7 @@ from transformers import PreTrainedTokenizerBase
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -459,7 +460,7 @@ def load_tokenized_prepared_datasets(
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged datasets")
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
if cfg.sample_packing and not cfg.skip_prepare_dataset:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
@@ -609,7 +610,20 @@ def get_dataset_wrapper(
|
||||
)
|
||||
elif cfg.skip_prepare_dataset:
|
||||
dataset_wrapper = dataset
|
||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||
elif ds_strategy := config_dataset.type.startswith(
|
||||
"bradley_terry"
|
||||
) and bradley_terry_load(
|
||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||
):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
)
|
||||
elif ds_strategy := load(
|
||||
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
|
||||
):
|
||||
if isinstance(ds_strategy, DatasetWrappingStrategy):
|
||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||
else:
|
||||
|
||||
@@ -306,7 +306,11 @@ def process_pretraining_datasets_for_packing(
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
|
||||
if (
|
||||
not cfg.total_num_tokens
|
||||
and not cfg.skip_prepare_dataset
|
||||
and not cfg.reward_model
|
||||
):
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
@@ -323,6 +327,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
not skip_estimates
|
||||
and not cfg.total_supervised_tokens
|
||||
and not cfg.skip_prepare_dataset
|
||||
and not cfg.reward_model
|
||||
):
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
|
||||
74
tests/e2e/test_reward_model_llama.py
Normal file
74
tests/e2e/test_reward_model_llama.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
E2E tests for reward model lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestRewardModelLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama reward models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_rm_fft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"model_type": "AutoModelForSequenceClassification",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"chat_template": "alpaca",
|
||||
"reward_model": True,
|
||||
"sequence_len": 1024,
|
||||
"pad_to_sequence_len": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||
"type": "bradley_terry.chat_template",
|
||||
},
|
||||
],
|
||||
"remove_unused_columns": False,
|
||||
"max_steps": 10,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"gradient_checkpointing": True,
|
||||
"warmup_ratio": 0.1,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
Reference in New Issue
Block a user