more fixes and optimizations
This commit is contained in:
@@ -253,9 +253,7 @@ def train(
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
|
||||
train_dataset.cleanup_cache_files()
|
||||
eval_dataset.cleanup_cache_files()
|
||||
accelerator.wait_for_everyone()
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
@@ -382,6 +380,10 @@ def train(
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
train_dataset.cleanup_cache_files()
|
||||
if eval_dataset:
|
||||
eval_dataset.cleanup_cache_files()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(train)
|
||||
|
||||
Reference in New Issue
Block a user