setup precision config for bf16
This commit is contained in:
@@ -484,6 +484,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
CommunicationDType,
|
CommunicationDType,
|
||||||
DDPShampooConfig,
|
DDPShampooConfig,
|
||||||
FSDPShampooConfig,
|
FSDPShampooConfig,
|
||||||
|
PrecisionConfig,
|
||||||
SGDGraftingConfig,
|
SGDGraftingConfig,
|
||||||
)
|
)
|
||||||
from distributed_shampoo.utils.shampoo_fsdp_utils import (
|
from distributed_shampoo.utils.shampoo_fsdp_utils import (
|
||||||
@@ -551,11 +552,23 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
num_trainers_per_group=self.args.world_size,
|
num_trainers_per_group=self.args.world_size,
|
||||||
communicate_params=False,
|
communicate_params=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
precision_config = None
|
||||||
|
if self.args.bf16:
|
||||||
|
precision_config = PrecisionConfig(
|
||||||
|
computation_dtype=torch.bfloat16,
|
||||||
|
factor_matrix_dtype=torch.bfloat16,
|
||||||
|
inv_factor_matrix_dtype=torch.bfloat16,
|
||||||
|
filtered_grad_dtype=torch.bfloat16,
|
||||||
|
momentum_dtype=torch.bfloat16,
|
||||||
|
grafting_state_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
DistributedShampoo(
|
DistributedShampoo(
|
||||||
optimizer_grouped_parameters,
|
optimizer_grouped_parameters,
|
||||||
grafting_config=grafting_config,
|
grafting_config=grafting_config,
|
||||||
distributed_config=distributed_config,
|
distributed_config=distributed_config,
|
||||||
|
precision_config=precision_config,
|
||||||
**optim_args,
|
**optim_args,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user