diff --git a/scripts/finetune.py b/scripts/finetune.py
index 329da1b22..846127d29 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -14,13 +14,13 @@ import torch
 import yaml
 
 # add src to the pythonpath so we don't need to pip install this
-from accelerate import Accelerator
 from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import barrier, is_main_process
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import (
@@ -173,7 +173,6 @@ def train(
     prepare_ds_only: bool = False,
     **kwargs,
 ):
-    accelerator = Accelerator()
 
     if Path(config).is_dir():
         config = choose_config(config)
@@ -243,17 +242,17 @@ def train(
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
 
-    if accelerator.is_local_main_process:
+    if is_main_process():
         # process on rank 0 first so it gets cached so other ranks load from cache
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
-    accelerator.wait_for_everyone()
-    if not accelerator.is_local_main_process:
+    barrier()
+    if not is_main_process():
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
-    accelerator.wait_for_everyone()
+    barrier()
 
     total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
     if cfg.debug or "debug" in kwargs:
@@ -366,23 +365,17 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
-        with model.summon_full_params():
-            model.save_pretrained(
-                cfg.output_dir,
-                is_main_process=trainer.accelerator.is_main_process,
-                save_function=trainer.accelerator.save,
-                state_dict=trainer.accelerator.get_state_dict(model),
-            )
+        trainer.save_model(cfg.output_dir)
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
 
         model.save_pretrained(cfg.output_dir)
-        # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
-
-    train_dataset.cleanup_cache_files()
-    if eval_dataset:
-        eval_dataset.cleanup_cache_files()
+    trainer.accelerator.wait_for_everyone()
+    if trainer.accelerator.is_main_process:
+        train_dataset.cleanup_cache_files()
+        if eval_dataset:
+            eval_dataset.cleanup_cache_files()
 
 
 if __name__ == "__main__":
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 3861a8c74..d8053ba15 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -8,7 +8,6 @@ from pathlib import Path
 from typing import List, Tuple, Union
 
 import torch
-from accelerate import Accelerator
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -37,9 +36,9 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
+from axolotl.utils.distributed import barrier, is_main_process
 
 LOG = logging.getLogger("axolotl")
-accelerator = Accelerator()
 
 
 def load_tokenized_prepared_datasets(
@@ -112,16 +111,14 @@ def load_tokenized_prepared_datasets(
             local_path = Path(d.path)
             if local_path.exists():
                 if local_path.is_dir():
-                    try:
-                        ds = load_from_disk(d.path)
-                    except FileNotFoundError:
-                        ds = load_dataset(
-                            d.path,
-                            name=d.name,
-                            data_files=d.data_files,
-                            streaming=False,
-                            split=None,
-                        )
+                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
+                    ds = load_dataset(
+                        d.path,
+                        name=d.name,
+                        data_files=d.data_files,
+                        streaming=False,
+                        split=None,
+                    )
                 elif local_path.is_file():
                     ds = load_dataset(
                         "json",
@@ -445,7 +442,7 @@ def load_prepare_datasets(
             to_hash_test.encode(), usedforsecurity=False
         ).hexdigest()
 
-        if accelerator.is_local_main_process:
+        if is_main_process():
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
@@ -453,8 +450,8 @@ def load_prepare_datasets(
                 train_new_fingerprint=train_fingerprint,
                 test_new_fingerprint=test_fingerprint,
             )
-        accelerator.wait_for_everyone()
-        if not accelerator.is_local_main_process:
+        barrier()
+        if not is_main_process():
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
@@ -462,7 +459,7 @@ def load_prepare_datasets(
                 train_new_fingerprint=train_fingerprint,
                 test_new_fingerprint=test_fingerprint,
            )
-        accelerator.wait_for_everyone()
+        barrier()
 
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
new file mode 100644
index 000000000..855c95304
--- /dev/null
+++ b/src/axolotl/utils/distributed.py
@@ -0,0 +1,30 @@
+"""
+utility helpers for distributed checks
+"""
+import torch.distributed as dist
+
+
+def is_distributed():
+    """
+    Check if distributed training is initialized.
+    """
+    return dist.is_available() and dist.is_initialized()
+
+
+def barrier():
+    """
+    Acts as a barrier to wait for all processes. This ensures that all processes
+    reach the barrier before proceeding further.
+    """
+    if is_distributed():
+        dist.barrier()
+
+
+def is_main_process():
+    """
+    Check if the current process is the main process.
+    If not in distributed mode, always return True.
+    """
+    if not is_distributed():
+        return True
+    return dist.get_rank() == 0
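
Note: the new `barrier()` / `is_main_process()` helpers implement the rank-0-first caching pattern used at both call sites above: rank 0 performs the expensive datasets operation first (so the result lands in the HF datasets cache), every rank synchronizes, then the remaining ranks repeat the same call and read from the cache. A minimal sketch of that pattern, not part of the patch; `run_on_rank_zero_first` and `expensive_map` are hypothetical names standing in for calls like `process_datasets_for_packing` or `train_test_split`:

    from axolotl.utils.distributed import barrier, is_main_process

    def run_on_rank_zero_first(dataset, expensive_map):
        # illustrative sketch only -- not part of the patch
        if is_main_process():
            # rank 0 goes first so the result is written to the datasets cache
            dataset = expensive_map(dataset)
        barrier()  # wait until rank 0 has finished and cached its result
        if not is_main_process():
            # remaining ranks repeat the call and load from rank 0's cache
            dataset = expensive_map(dataset)
        barrier()  # keep all ranks in lockstep before continuing
        return dataset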