Compare commits

...

1 Commits
dft ... fp8

Author SHA1 Message Date
Wing Lian
8836986a92 support for fp8 2023-11-10 02:35:19 -05:00
3 changed files with 9 additions and 1 deletion

View File

@@ -483,6 +483,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
if self.cfg.fp8:
training_arguments_kwargs["fp16"] = False
training_arguments_kwargs["bf16"] = False
training_arguments_kwargs["tf32"] = self.cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps

View File

@@ -70,7 +70,9 @@ def normalize_config(cfg):
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
if cfg.bf16 or cfg.bfloat16:
if cfg.fp8:
cfg.torch_dtype = torch.bfloat16
elif cfg.bf16 or cfg.bfloat16:
cfg.torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
cfg.torch_dtype = torch.float16

View File

@@ -268,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
if cfg.fp8:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset