diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
new file mode 100644
index 000000000..64c3e7629
--- /dev/null
+++ b/examples/qwen2/dpo.yaml
@@ -0,0 +1,67 @@
+base_model: Qwen/Qwen2.5-0.5B
+
+strict: false
+
+chat_template: qwen_25
+rl: dpo
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_dpo_test
+    type: chat_template.default
+    field_messages: conversation
+    field_chosen: chosen
+    field_rejected: rejected
+    message_field_role: role
+    message_field_content: content
+    roles:
+      system:
+        - system
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./outputs/dpo-out
+
+sequence_len: 2048
+sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 1cad1a8c3..14690580d 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1038,24 +1038,37 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         return super().push_to_hub(*args, **kwargs)
 
+    @staticmethod
     def tokenize_row(
-        self,
         features,
         processing_class,
         max_prompt_length,
         max_completion_length,
         add_special_tokens,
     ) -> Dict:
-        res = super().tokenize_row(
+        res = DPOTrainer.tokenize_row(
             features,
             processing_class,
             max_prompt_length,
             max_completion_length,
             add_special_tokens,
         )
-        if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
+        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
+        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
             for key in res.keys():
                 res[key] = res[key][1:]
+
+        if processing_class.bos_token and processing_class.bos_token_id is not None:
+            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
+            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
+                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
+                res["chosen_labels"] = res["chosen_labels"][1:]
+                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
+            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
+                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
+                res["rejected_labels"] = res["rejected_labels"][1:]
+                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
+
         return res
 
     def training_step(
diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py
new file mode 100644
index 000000000..7a343f4d3
--- /dev/null
+++ b/tests/e2e/test_qwen.py
@@ -0,0 +1,85 @@
+"""
+E2E tests for qwen
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.qwen")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestE2eQwen:
+    """
+    Test cases for qwen models
+    """
+
+    @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
+    def test_dpo(self, base_model, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": base_model,
+                "rl": "dpo",
+                "chat_template": "qwen_25",
+                "sequence_len": 2048,
+                "val_set_size": 0.0,
+                "datasets": [
+                    {
+                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+                        "split": "train",
+                        "type": "chat_template.default",
+                        "field_messages": "conversation",
+                        "field_chosen": "chosen",
+                        "field_rejected": "rejected",
+                        "message_field_role": "role",
+                        "message_field_content": "content",
+                        "roles": {
+                            "system": ["system"],
+                            "user": ["user"],
+                            "assistant": ["assistant"],
+                        },
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "warmup_steps": 20,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "bf16": "auto",
+                "tf32": True,
+                "gradient_checkpointing": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
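For reviewers, a minimal sketch (not part of the patch) of the tokenizer behavior the `tokenize_row` override guards against. It assumes the Hub checkpoint already referenced in this diff; the expected output reflects Qwen's tokenizer config, which defines no BOS token:

```python
# Sketch only: shows why the `processing_class.bos_token is None` branch
# exists. Qwen tokenizers ship without a BOS token, so TRL's row
# tokenization can leave a spurious leading element that must be stripped.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
# Expected for Qwen: None None (triggers the first branch of tokenize_row)
print(tok.bos_token, tok.bos_token_id)
```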