use custom distributed checks

This commit is contained in:
Wing Lian
2023-08-08 13:35:04 -04:00
parent 035b3c760c
commit 1b8747e319
3 changed files with 54 additions and 34 deletions

View File

@@ -14,13 +14,13 @@ import torch
import yaml import yaml
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from accelerate import Accelerator
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
@@ -173,7 +173,6 @@ def train(
prepare_ds_only: bool = False, prepare_ds_only: bool = False,
**kwargs, **kwargs,
): ):
accelerator = Accelerator()
if Path(config).is_dir(): if Path(config).is_dir():
config = choose_config(config) config = choose_config(config)
@@ -243,17 +242,17 @@ def train(
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
if accelerator.is_local_main_process: if is_main_process():
# process on rank 0 first so it gets cached so other ranks load from cache # process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
accelerator.wait_for_everyone() barrier()
if not accelerator.is_local_main_process: if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
accelerator.wait_for_everyone() barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
@@ -366,23 +365,17 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.fsdp:
with model.summon_full_params(): trainer.save_model(cfg.output_dir)
model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(model),
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
train_dataset.cleanup_cache_files() train_dataset.cleanup_cache_files()
if eval_dataset: if eval_dataset:
eval_dataset.cleanup_cache_files() eval_dataset.cleanup_cache_files()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -8,7 +8,6 @@ from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -37,9 +36,9 @@ from axolotl.prompters import (
ShareGPTPrompter, ShareGPTPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
) )
from axolotl.utils.distributed import barrier, is_main_process
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
accelerator = Accelerator()
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
@@ -112,16 +111,14 @@ def load_tokenized_prepared_datasets(
local_path = Path(d.path) local_path = Path(d.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
try: # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_from_disk(d.path) ds = load_dataset(
except FileNotFoundError: d.path,
ds = load_dataset( name=d.name,
d.path, data_files=d.data_files,
name=d.name, streaming=False,
data_files=d.data_files, split=None,
streaming=False, )
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds = load_dataset( ds = load_dataset(
"json", "json",
@@ -445,7 +442,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False to_hash_test.encode(), usedforsecurity=False
).hexdigest() ).hexdigest()
if accelerator.is_local_main_process: if is_main_process():
dataset = dataset.train_test_split( dataset = dataset.train_test_split(
test_size=cfg.val_set_size, test_size=cfg.val_set_size,
shuffle=False, shuffle=False,
@@ -453,8 +450,8 @@ def load_prepare_datasets(
train_new_fingerprint=train_fingerprint, train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint, test_new_fingerprint=test_fingerprint,
) )
accelerator.wait_for_everyone() barrier()
if not accelerator.is_local_main_process: if not is_main_process():
dataset = dataset.train_test_split( dataset = dataset.train_test_split(
test_size=cfg.val_set_size, test_size=cfg.val_set_size,
shuffle=False, shuffle=False,
@@ -462,7 +459,7 @@ def load_prepare_datasets(
train_new_fingerprint=train_fingerprint, train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint, test_new_fingerprint=test_fingerprint,
) )
accelerator.wait_for_everyone() barrier()
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]

View File

@@ -0,0 +1,30 @@
"""
utility helpers for distributed checks
"""
import torch.distributed as dist
def is_distributed():
"""
Check if distributed training is initialized.
"""
return dist.is_available() and dist.is_initialized()
def barrier():
"""
Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further.
"""
if is_distributed():
dist.barrier()
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
if not is_distributed():
return True
return dist.get_rank() == 0