From 4f7c04bae0092b7249dfa86759b61f2fe1421a34 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 8 Aug 2023 03:16:00 -0400 Subject: [PATCH] more fixes and optimizations --- scripts/finetune.py | 8 ++++-- src/axolotl/utils/data.py | 49 ++++++++++++++++++++++++++++++++- src/axolotl/utils/dataloader.py | 9 ++++++ src/axolotl/utils/trainer.py | 12 ++++---- 4 files changed, 68 insertions(+), 10 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 488d23c11..329da1b22 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -253,9 +253,7 @@ def train( train_dataset, eval_dataset = process_datasets_for_packing( cfg, train_dataset, eval_dataset ) - - train_dataset.cleanup_cache_files() - eval_dataset.cleanup_cache_files() + accelerator.wait_for_everyone() total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) if cfg.debug or "debug" in kwargs: @@ -382,6 +380,10 @@ def train( # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time + train_dataset.cleanup_cache_files() + if eval_dataset: + eval_dataset.cleanup_cache_files() + if __name__ == "__main__": fire.Fire(train) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6661f8db9..b2a147dab 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,5 +1,6 @@ """Module containing data utilities""" import functools +import hashlib import itertools import logging from hashlib import md5 @@ -7,6 +8,7 @@ from pathlib import Path from typing import List, Tuple, Union import torch +from accelerate import Accelerator from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase @@ -37,6 +39,7 @@ from axolotl.prompters import ( ) LOG = logging.getLogger("axolotl") +accelerator = Accelerator() def load_tokenized_prepared_datasets( @@ -416,7 +419,51 @@ def load_prepare_datasets( ) if cfg.val_set_size: - dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) + # ensure we end up with the same fingerprint by doing rank0 first and being able to cache + to_hash_train = ( + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "train" + + "|" + + cfg.seed + ) + to_hash_test = ( + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "test" + + "|" + + cfg.seed + ) + train_fingerprint = hashlib.md5( + to_hash_train.encode(), usedforsecurity=False + ).hexdigest() + test_fingerprint = hashlib.md5( + to_hash_test.encode(), usedforsecurity=False + ).hexdigest() + + if accelerator.is_local_main_process: + dataset = dataset.train_test_split( + test_size=cfg.val_set_size, + shuffle=False, + seed=cfg.seed, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + accelerator.wait_for_everyone() + if not accelerator.is_local_main_process: + dataset = dataset.train_test_split( + test_size=cfg.val_set_size, + shuffle=False, + seed=cfg.seed, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + accelerator.wait_for_everyone() + train_dataset = dataset["train"] eval_dataset = dataset["test"] else: diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 167f3957a..78ad6c150 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -263,6 +263,15 @@ class MultipackDistributedDataloader: def __len__(self): # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get # the same share of total tokens + # if not self.eff_total_used: + # batches, _ = self.generate_batches(set_stats=True) + # LOG.info( + # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + # f"actual packing efficiency: {self.efficiency()}" + # ) + return max(1, self._len_est()) + + def len_w_stats(self): if not self.eff_total_used: batches, _ = self.generate_batches(set_stats=True) LOG.info( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 51f41168c..1f7992b0a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -175,7 +175,7 @@ class AxolotlTrainer(Trainer): # If set to True, the dataloader prepared is only iterated through on the # main process and then the batches are split and broadcast to each process self.accelerator.dispatch_batches = True - return self.accelerator.prepare( + return self.accelerator.prepare_data_loader( MultipackDistributedDataloader( self.train_dataset, batch_size=self._train_batch_size, @@ -200,7 +200,7 @@ class AxolotlTrainer(Trainer): # If set to True, the datalaoder prepared is only iterated through on the # main process and then the batches are split and broadcast to each process self.accelerator.dispatch_batches = True - return self.accelerator.prepare( + return self.accelerator.prepare_data_loader( MultipackDistributedDataloader( eval_dataset, batch_size=self.args.eval_batch_size, @@ -305,6 +305,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): / cfg.sample_packing_eff_est / 2048 // cfg.batch_size + // int(os.environ.get("WORLD_SIZE", 1)) ) - 1 ) @@ -327,17 +328,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): sampler=sampler, packing_efficiency_estimate=cfg.sample_packing_eff_est, sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier, - device_count=int(os.environ.get("WORLD_SIZE", 1)), ) - data_loader_len = len(data_loader) + data_loader_len = data_loader.len_w_stats() actual_eff = data_loader.efficiency() LOG.info(f"data_loader_len: {data_loader_len}") total_num_steps = int( - math.ceil( + math.floor( data_loader_len * cfg.micro_batch_size * cfg.num_epochs - / cfg.batch_size + // cfg.batch_size ) ) LOG.info(