From 68601ec6ad1cc0e8cb855376586e6eef6a8aa270 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 22 Apr 2024 16:00:05 -0400
Subject: [PATCH] make sure everything stays in the same dtype when using dpo
 + FSDP (#1559)

---
 src/axolotl/core/trainer_builder.py |  4 ++++
 src/axolotl/utils/models.py         | 10 ++++++++++
 2 files changed, 14 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index fdb081003..6bddb9574 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -54,6 +54,7 @@ from axolotl.utils.collators import (
     MambaDataCollator,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from axolotl.utils.models import ensure_dtype
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
@@ -1569,6 +1570,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        if self.cfg.fsdp:
+            ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+
         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 52d8db047..8537b7e75 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
+
+
+def ensure_dtype(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        try:
+            if module.weight.dtype != dtype:
+                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
+                module.to(dtype)
+        except AttributeError:
+            pass
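
For reference, a minimal standalone sketch of what ensure_dtype does once the
patch is applied. The toy nn.Sequential model and its mixed-dtype layers below
are hypothetical, chosen only to illustrate the conversion; they are not part
of axolotl:

    import torch
    import torch.nn as nn

    from axolotl.utils.models import ensure_dtype

    # A toy model whose second layer is still float32. Per the patch subject,
    # leftover modules in a different dtype are what cause trouble under
    # DPO + FSDP, which the call site in trainer_builder.py guards against.
    model = nn.Sequential(
        nn.Linear(4, 4).to(torch.bfloat16),  # already the target dtype: left untouched
        nn.Linear(4, 4),                     # float32: converted in place
    )

    ensure_dtype(model, dtype=torch.bfloat16)
    assert all(p.dtype == torch.bfloat16 for p in model.parameters())

The helper walks model.named_modules() and converts any module whose .weight
has the wrong dtype; modules without a .weight attribute (containers,
activations, dropout) raise AttributeError inside the try block and are
skipped.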