Compare commits
1 Commits
keep_in_me
...
fp8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8836986a92 |
@@ -483,6 +483,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["fp16"] = (
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
if self.cfg.fp8:
|
||||
training_arguments_kwargs["fp16"] = False
|
||||
training_arguments_kwargs["bf16"] = False
|
||||
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
@@ -70,7 +70,9 @@ def normalize_config(cfg):
|
||||
else:
|
||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
if cfg.fp8:
|
||||
cfg.torch_dtype = torch.bfloat16
|
||||
elif cfg.bf16 or cfg.bfloat16:
|
||||
cfg.torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
cfg.torch_dtype = torch.float16
|
||||
|
||||
@@ -268,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
if cfg.fp8:
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
|
||||
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
|
||||
Reference in New Issue
Block a user