Compare commits
10 Commits: 432b17eee1 ... 6dc0f4dac6
| Author | SHA1 | Date |
|---|---|---|
| | 6dc0f4dac6 | |
| | 1fceaa20e3 | |
| | 7ee7b4c493 | |
| | d2e51406a1 | |
| | 5d55c08086 | |
| | cc2815a3cc | |
| | 3b648f6bbe | |
| | 5294fe5a99 | |
| | 4b1273ae1e | |
| | 394806ab30 | |
```diff
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs

-trl==0.12.0
+trl @ git+https://github.com/huggingface/trl.git@5e90682836969310e16ed8aa711dd429f85863b7
 zstandard==0.22.0
 fastcore
```
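Pinning `trl` to a git commit (a PEP 508 direct reference) means the version string alone no longer identifies the code. A minimal sketch, standard library only, for confirming what actually got installed — pip records VCS installs in `direct_url.json` (PEP 610):

```python
# Check the installed trl distribution and, for VCS installs,
# the recorded source URL and commit.
from importlib.metadata import distribution, version

print(version("trl"))
dist = distribution("trl")
# None for normal PyPI installs; JSON with the git URL/commit for VCS installs.
print(dist.read_text("direct_url.json"))
```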
```diff
@@ -1926,16 +1926,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kto_undesirable_weight or 1.0
             )

             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

         else:
             training_args_cls = AxolotlDPOConfig

+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+
+            training_args_kwargs["max_target_length"] = None
+            if self.cfg.max_prompt_len is not None:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
+
+            if self.cfg.dpo_use_weighting is not None:
+                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
+
+            if self.cfg.rl == "ipo":
+                training_args_kwargs["loss_type"] = "ipo"
+            if self.cfg.dpo_label_smoothing:
+                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
+
+            if self.cfg.precompute_ref_log_probs is not None:
+                training_args_kwargs["precompute_ref_log_probs"] = self.cfg.precompute_ref_log_probs
+
+            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb

         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
```
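The refactor above moves the DPO-specific options into the training-args config class. The underlying pattern is worth seeing in isolation: optional `cfg` values are gathered into a kwargs dict and splatted into a dataclass, so anything unset keeps its default. A minimal, self-contained sketch — the class and field names here are stand-ins, not Axolotl's real definitions:

```python
from dataclasses import dataclass

@dataclass
class DemoDPOConfig:
    # Stand-in for AxolotlDPOConfig / trl's DPOConfig; fields are illustrative.
    output_dir: str
    max_length: int = 512
    use_weighting: bool = False
    label_smoothing: float = 0.0

def build_config(cfg: dict) -> DemoDPOConfig:
    kwargs = {}
    # Only forward options the user actually set, mirroring the
    # training_args_kwargs pattern in the diff above.
    if cfg.get("sequence_len"):
        kwargs["max_length"] = cfg["sequence_len"]
    if cfg.get("dpo_use_weighting") is not None:
        kwargs["use_weighting"] = cfg["dpo_use_weighting"]
    if cfg.get("dpo_label_smoothing"):
        kwargs["label_smoothing"] = cfg["dpo_label_smoothing"]
    return DemoDPOConfig(output_dir=cfg["output_dir"], **kwargs)

print(build_config({"output_dir": "./outputs/out", "dpo_use_weighting": True}))
```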
```diff
@@ -1955,27 +1971,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
     def build(self, total_num_steps):
         training_args = self.build_training_arguments(total_num_steps)
         dpo_trainer_kwargs = {}
-        if self.cfg.rl == "ipo":
-            dpo_trainer_kwargs["loss_type"] = "ipo"
-        if self.cfg.dpo_label_smoothing:
-            dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
-
         if self.eval_dataset:
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
             dpo_trainer_kwargs["peft_config"] = self.peft_config
-        if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs[
-                "precompute_ref_log_probs"
-            ] = self.cfg.precompute_ref_log_probs

         if self.cfg.rl in ["dpo", "ipo"]:
             trainer_cls = AxolotlDPOTrainer
             trainer_cls_args = [self.model, self.model_ref]
-
-            # these aren't used for the ORPO trainer
-            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["max_target_length"] = None
-            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
```
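For orientation on what `use_weighting` controls: it enables per-pair weighting of the preference loss. Below is a hedged sketch of a DPO-style loss with optional per-example weights; it is illustrative only, not lifted from trl's `DPOTrainer`.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps,
             beta=0.1, weights=None):
    # Standard DPO objective: sigmoid loss on the difference between the
    # policy's and the reference model's chosen-vs-rejected log-ratios.
    logits = beta * (
        (policy_chosen_logps - policy_rejected_logps)
        - (ref_chosen_logps - ref_rejected_logps)
    )
    losses = -F.logsigmoid(logits)
    if weights is not None:
        # Per-example weights rescale each preference pair's contribution.
        losses = losses * weights
    return losses.mean()

# Toy usage: random log-probabilities for a batch of 4 preference pairs.
b = 4
args = [torch.randn(b) for _ in range(4)]
print(dpo_loss(*args, weights=torch.rand(b)))
```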
test.yml (8 changes)
```diff
@@ -5,14 +5,14 @@ load_in_4bit: false
 strict: false

 datasets:
-  - path: tatsu-lab/alpaca
-    type: alpaca
+  - path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
+    type: chatml.ultra
+    split: train
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.2
+val_set_size: 0.1
 output_dir: ./outputs/out

 sequence_len: 2048
 sample_packing: true
 pad_to_sequence_len: true

 wandb_project:
```
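The `val_set_size` change holds out 10% of the data for evaluation instead of 20%. Roughly what that fraction amounts to with the `datasets` library — an illustrative sketch, since Axolotl's actual split logic may differ:

```python
from datasets import load_dataset

ds = load_dataset("arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", split="train")
# val_set_size: 0.1 ~ hold out 10% of examples for evaluation.
split = ds.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = split["train"], split["test"]
print(len(train_ds), len(eval_ds))
```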
test2.yml (new file, 43 lines)
```yaml
base_model: JackFram/llama-68m

load_in_8bit: true

datasets:
  - path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
    type: chatml.ultra
    split: train
output_dir: ./outputs/lora-out

sequence_len: 1024

adapter: lora
lora_r: 64
lora_alpha: 32
lora_dropout: 0.1
lora_target_linear: true

rl: dpo
dpo_use_weighting: true

wandb_project: check_dpotrainer
wandb_entity: axolotl-ai
wandb_watch:
wandb_name: baseline/dpo_base/dpo_use_weighting
wandb_log_model:

num_epochs: 1
micro_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 0.00001
optimizer: paged_adamw_8bit
lr_scheduler: cosine
max_steps: 20
save_steps: 10
warmup_steps: 5
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
#special_tokens:
#  pad_token: <|end_of_text|>
```
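A quick sanity check that the new config parses and actually enables the feature under test — assuming the file above is saved locally as `test2.yml`:

```python
import yaml  # PyYAML

with open("test2.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# The run exercises weighted DPO on a small llama model.
assert cfg["rl"] == "dpo"
assert cfg["dpo_use_weighting"] is True
print(cfg["base_model"], cfg["max_steps"], cfg["learning_rate"])
```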