Add ORPO example and e2e test (#1572)

* add example for mistral orpo

* sample_packing: false for orpo

* go to load_dataset (since load_rl_datasets require a transfom_fn, which only dpo uses currently)
This commit is contained in:
Motoki Wu
2024-04-27 09:07:06 -07:00
committed by GitHub
parent 68601ec6ad
commit 98c25e15cb
4 changed files with 131 additions and 1 deletions

1
.gitignore vendored
View File

@@ -133,6 +133,7 @@ venv/
ENV/
env.bak/
venv.bak/
venv3.10/
# Spyder project settings
.spyderproject

View File

@@ -49,7 +49,7 @@ remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: orpo.chat_template
type: chat_template.argilla
```
#### Using local dataset files

View File

@@ -0,0 +1,82 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -158,3 +158,50 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_orpo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "orpo",
"orpo_alpha": 0.1,
"remove_unused_columns": False,
"chat_template": "chatml",
"datasets": [
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()