add flash attn context for efficient training and attempt setting model to train mode:

This commit is contained in:
Wing Lian
2023-05-27 18:12:12 -04:00
parent 1edc30c786
commit 8792199799

View File

@@ -252,6 +252,24 @@ def train(
model.save_pretrained(cfg.output_dir)
return
if cfg.debug:
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
),
tokenizer,
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
try:
model.train()
except:
pass
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False
@@ -297,7 +315,11 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")