diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index 44f9c7e49..d61c72a37 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -72,4 +72,5 @@ fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
   fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
 special_tokens:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4e8b36905..1a073ca04 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1846,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             )
         if self.cfg.fsdp:
             ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+            if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
+                ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
 
         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 3e8d50f5e..4f47d59bf 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1102,9 +1102,20 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
 def ensure_dtype(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        weight_mismatch = False
+        bias_mismatch = False
         try:
-            if module.weight.dtype != dtype:
-                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
-                module.to(dtype)
+            weight_mismatch = module.weight.dtype != dtype
         except AttributeError:
             pass
+        try:
+            bias_mismatch = module.bias.dtype != dtype
+        except AttributeError:
+            pass
+
+        if weight_mismatch:
+            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+        if bias_mismatch:
+            print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+        if weight_mismatch or bias_mismatch:
+            module.to(dtype)
diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py
new file mode 100644
index 000000000..2513be69e
--- /dev/null
+++ b/tests/e2e/multigpu/test_qwen2.py
@@ -0,0 +1,98 @@
+"""
+E2E tests for multigpu qwen2
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMultiGPUQwen2(unittest.TestCase):
+    """
+    Test case for Qwen2 models using QLoRA with FSDP
+    """
+
+    @with_temp_dir
+    def test_qlora_fsdp_dpo(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2-1.5B",
+                "load_in_4bit": True,
+                "rl": "dpo",
+                "chat_template": "chatml",
+                "sequence_len": 2048,
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.05,
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "split": "train",
+                        "type": "chatml.intel",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 100,
+                "warmup_steps": 20,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "bf16": "auto",
+                "tf32": True,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {
+                    "use_reentrant": False,
+                },
+                "fsdp": [
+                    "full_shard",
+                    "auto_wrap",
+                ],
+                "fsdp_config": {
+                    "fsdp_limit_all_gathers": True,
+                    "fsdp_offload_params": False,
+                    "fsdp_sync_module_states": True,
+                    "fsdp_use_orig_params": False,
+                    "fsdp_cpu_ram_efficient_loading": False,
+                    "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                    "fsdp_state_dict_type": "FULL_STATE_DICT",
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "fsdp_sharding_strategy": "FULL_SHARD",
+                },
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )