feat: Add GDPO Support (#3353)
* gdpo support - test left * lint * fixxes for vllm serv * test advantages * docss * lint * lint = * gdpo simple + lint * lint nit * example * lint * trl 0.27.0 * blocklist * test assert rmv * add validation check for GDPO + sum_then_normalize --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
68
examples/llama-3/qlora-1b-gdpo.yaml
Normal file
68
examples/llama-3/qlora-1b-gdpo.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: meta-llama/Llama-3.2-1B-Instruct
|
||||
|
||||
chat_template: llama3
|
||||
|
||||
rl: gdpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 128
|
||||
num_generations: 2
|
||||
temperature: 0.7
|
||||
top_p: 0.95
|
||||
|
||||
use_vllm: false
|
||||
|
||||
|
||||
multi_objective_aggregation: normalize_then_sum
|
||||
|
||||
reward_funcs:
|
||||
- rwd.format_reward
|
||||
- rwd.correctness_reward
|
||||
reward_weights: [1.0, 2.0]
|
||||
|
||||
log_completions: true
|
||||
num_completions_to_print: 3
|
||||
scale_rewards: true
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
split: train[:1000]
|
||||
type: rwd.gsm8k_transform
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/llama3-gdpo-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
max_steps: 100
|
||||
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 10
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
flash_attention: true
|
||||
logging_steps: 1
|
||||
save_steps: 50
|
||||
save_safetensors: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
|
||||
seed: 42
|
||||
Reference in New Issue
Block a user