diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index b8b2bded0..48701c87a 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -52,6 +52,26 @@ datasets:
     type: chat_template.argilla
 ```
+
+#### KTO
+
+```yaml
+rl: kto
+rl_beta: 0.5
+kto_desirable_weight: 0.2
+
+remove_unused_columns: false
+
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    type: llama3.ultra
+    split: train
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+```
+
 #### Using local dataset files
 ```yaml
 datasets:
diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml
new file mode 100644
index 000000000..a876d8fd7
--- /dev/null
+++ b/examples/llama-3/qlora-1b-kto.yaml
@@ -0,0 +1,75 @@
+base_model: meta-llama/Llama-3.2-1B
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+rl: kto
+rl_beta: 0.5
+kto_desirable_weight: 0.2
+
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    type: llama3.ultra
+    split: train
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/qlora-out
+
+remove_unused_columns: false
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: false # not supported with kto
+eval_sample_packing: false
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 64
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 20
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: "<|end_of_text|>"
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 24ea62c77..3671e1bb9 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1475,6 +1475,27 @@ class AxolotlInputConfig(
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_kto_config(cls, data):
+        if data.get("rl") == "kto":
+            if data.get("sample_packing") or data.get("eval_sample_packing"):
+                raise ValueError("sample_packing is not supported with kto")
+
+            if data.get("remove_unused_columns") is not False:
+                raise ValueError("Set `remove_unused_columns: False` when using kto")
+
+            if data.get("gradient_checkpointing") and not (
+                data.get("gradient_checkpointing_kwargs")
+                and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
+                and data["gradient_checkpointing_kwargs"].get("use_reentrant")
+            ):
+                raise ValueError(
+                    "Set `gradient_checkpointing_kwargs: {use_reentrant: true}` when kto is enabled"
+                )
+
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""