Accelerate 1.8.1 and BNB 0.46.0 update (#2815)

* update accelerate to v1.8.0

* update bnb also

* fix multigpu ci timeout

* fix test set size

* use latest accelerate 1.8.1

* disable default dtype
This commit is contained in:
Wing Lian
2025-06-28 15:29:19 -04:00
committed by GitHub
parent a1a740608d
commit 81893c775c
11 changed files with 32 additions and 7 deletions

View File

@@ -223,8 +223,9 @@ def execute_training(
)
LOG.info("Starting trainer...")
if cfg.bf16:
torch.set_default_dtype(torch.bfloat16)
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
# if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)