Add flash-attention context for efficient training and attempt to set the model to train mode
This commit is contained in:
@@ -252,6 +252,24 @@ def train(
|
|||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if cfg.debug:
|
||||||
|
logging.info("check_dataset_labels...")
|
||||||
|
check_dataset_labels(
|
||||||
|
train_dataset.select(
|
||||||
|
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prepare_ds_only:
|
||||||
|
logging.info("Finished preparing dataset. Exiting...")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.train()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
||||||
|
|
||||||
model.config.use_cache = False
|
model.config.use_cache = False
|
||||||
@@ -297,7 +315,11 @@ def train(
|
|||||||
|
|
||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
if cfg.flash_optimum:
|
||||||
|
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
||||||
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
else:
|
||||||
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user