make sure everything stays in the same dtype when using dpo + FSDP (#1559)

commit 68601ec6ad (parent 60f5ce0569)
Author: Wing Lian
Date:   2024-04-22 16:00:05 -04:00 (committed via GitHub)

2 changed files with 14 additions and 0 deletions
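
Context for the change: FSDP flattens the parameters it wraps into a single buffer, which requires every parameter to share one dtype, while DPO trainer setup can leave individual modules (norm layers are a common case) in a different dtype than the rest of the model. The snippet below is a minimal, hypothetical sketch of that mixed-dtype situation; the toy model is not axolotl code.

import torch
import torch.nn as nn

# Hypothetical toy model, purely for illustration -- not axolotl code.
model = nn.Sequential(
    nn.Linear(16, 16),
    nn.LayerNorm(16),
    nn.Linear(16, 4),
).to(torch.bfloat16)

# Simulate a module left in float32, e.g. by a preparation step that
# upcasts norm layers for numerical stability.
model[1].to(torch.float32)

# FSDP flattens the parameters it wraps into a single buffer, so a mix of
# dtypes like this makes wrapping fail; the ensure_dtype() helper added
# below casts such stragglers back to one dtype before wrapping.
print({name: param.dtype for name, param in model.named_parameters()})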

src/axolotl/core/trainer_builder.py

@@ -54,6 +54,7 @@ from axolotl.utils.collators import (
     MambaDataCollator,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from axolotl.utils.models import ensure_dtype
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
@@ -1569,6 +1570,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        if self.cfg.fsdp:
+            ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+
         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)

src/axolotl/utils/models.py

@@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
+
+
+def ensure_dtype(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        try:
+            if module.weight.dtype != dtype:
+                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
+                module.to(dtype)
+        except AttributeError:
+            pass
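
A rough usage sketch of the new helper (not part of the diff): ensure_dtype walks model.named_modules() and casts any module whose .weight is in a different dtype, while the try/except AttributeError silently skips containers and modules without a weight. The model below is a hypothetical stand-in; in the trainer builder above, the real call site passes dpo_trainer.model and cfg.torch_dtype.

import torch
import torch.nn as nn

from axolotl.utils.models import ensure_dtype  # the helper added in this commit

# Hypothetical stand-in for dpo_trainer.model, purely for illustration.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
model[0].to(torch.bfloat16)  # leave the LayerNorm in float32 on purpose

target_dtype = torch.bfloat16  # mirrors cfg.torch_dtype in the trainer builder
ensure_dtype(model, dtype=target_dtype)

# Every weight-bearing module now shares the target dtype.
assert all(
    module.weight.dtype == target_dtype
    for module in model.modules()
    if getattr(module, "weight", None) is not None
)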