setup precision config for bf16

This commit is contained in:
Wing Lian
2024-09-18 12:01:38 -07:00
parent beaee36191
commit 992ea517b7

View File

@@ -484,6 +484,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
CommunicationDType,
DDPShampooConfig,
FSDPShampooConfig,
PrecisionConfig,
SGDGraftingConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import (
@@ -551,11 +552,23 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
num_trainers_per_group=self.args.world_size,
communicate_params=False,
)
precision_config = None
if self.args.bf16:
precision_config = PrecisionConfig(
computation_dtype=torch.bfloat16,
factor_matrix_dtype=torch.bfloat16,
inv_factor_matrix_dtype=torch.bfloat16,
filtered_grad_dtype=torch.bfloat16,
momentum_dtype=torch.bfloat16,
grafting_state_dtype=torch.bfloat16,
)
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
DistributedShampoo(
optimizer_grouped_parameters,
grafting_config=grafting_config,
distributed_config=distributed_config,
precision_config=precision_config,
**optim_args,
)
)