From 992ea517b7b3f4618598f4947da2cc6359d9c4aa Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Sep 2024 12:01:38 -0700 Subject: [PATCH] setup precision config for bf16 --- src/axolotl/core/trainer_builder.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 52ac22b7d..160fee101 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -484,6 +484,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): CommunicationDType, DDPShampooConfig, FSDPShampooConfig, + PrecisionConfig, SGDGraftingConfig, ) from distributed_shampoo.utils.shampoo_fsdp_utils import ( @@ -551,11 +552,23 @@ class AxolotlTrainer(SchedulerMixin, Trainer): num_trainers_per_group=self.args.world_size, communicate_params=False, ) + + precision_config = None + if self.args.bf16: + precision_config = PrecisionConfig( + computation_dtype=torch.bfloat16, + factor_matrix_dtype=torch.bfloat16, + inv_factor_matrix_dtype=torch.bfloat16, + filtered_grad_dtype=torch.bfloat16, + momentum_dtype=torch.bfloat16, + grafting_state_dtype=torch.bfloat16, + ) self.optimizer = ( # pylint: disable=attribute-defined-outside-init DistributedShampoo( optimizer_grouped_parameters, grafting_config=grafting_config, distributed_config=distributed_config, + precision_config=precision_config, **optim_args, ) )