use custom distributed checks

Wing Lian
2023-08-08 13:35:04 -04:00
parent 035b3c760c
commit 1b8747e319
3 changed files with 54 additions and 34 deletions

@@ -14,13 +14,13 @@ import torch
import yaml

# add src to the pythonpath so we don't need to pip install this
-from accelerate import Accelerator
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer

from axolotl.logging_config import configure_logging
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import (
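
The new import pulls barrier() and is_main_process() from axolotl.utils.distributed, one of the three files touched by this commit but not shown in this excerpt. A minimal sketch of what such helpers typically look like, assuming they are thin wrappers around torch.distributed (only the two imported names come from the diff; everything else below is an assumption):

# Hypothetical sketch of the axolotl.utils.distributed helpers imported above;
# the actual module added by this commit is not shown in this excerpt.
import torch.distributed as dist


def is_distributed() -> bool:
    # True only when a torch.distributed process group has been initialized
    return dist.is_available() and dist.is_initialized()


def is_main_process() -> bool:
    # Rank 0 counts as the main process; single-process runs are always "main"
    if not is_distributed():
        return True
    return dist.get_rank() == 0


def barrier() -> None:
    # Block until every rank reaches this point; no-op outside distributed runs
    if is_distributed():
        dist.barrier()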
@@ -173,7 +173,6 @@ def train(
    prepare_ds_only: bool = False,
    **kwargs,
):
-    accelerator = Accelerator()
    if Path(config).is_dir():
        config = choose_config(config)
@@ -243,17 +242,17 @@ def train(
        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None

-    if accelerator.is_local_main_process:
+    if is_main_process():
        # process on rank 0 first so it gets cached so other ranks load from cache
        train_dataset, eval_dataset = process_datasets_for_packing(
            cfg, train_dataset, eval_dataset
        )
-    accelerator.wait_for_everyone()
-    if not accelerator.is_local_main_process:
+    barrier()
+    if not is_main_process():
        train_dataset, eval_dataset = process_datasets_for_packing(
            cfg, train_dataset, eval_dataset
        )
-    accelerator.wait_for_everyone()
+    barrier()

    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
    if cfg.debug or "debug" in kwargs:
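
The hunk above is a rank-0-first pattern: the main process runs process_datasets_for_packing and warms the Hugging Face datasets cache, every other rank waits at the barrier and then repeats the call against that cache. A standalone sketch of the same ordering with plain torch.distributed (hypothetical file name and marker path, not part of this commit):

# rank0_first_demo.py (hypothetical) -- launch with:
#   torchrun --nproc_per_node=2 rank0_first_demo.py
import os

import torch.distributed as dist

dist.init_process_group(backend="gloo")  # torchrun supplies the rank/world-size env vars
rank = dist.get_rank()

marker = "/tmp/rank0_prepared.marker"  # stand-in for the shared dataset cache
if rank == 0:
    # rank 0 does the expensive work once and leaves a shared artifact behind
    with open(marker, "w", encoding="utf-8") as fh:
        fh.write("ready")
dist.barrier()  # everyone waits until rank 0 has finished

if rank != 0:
    # the remaining ranks can now rely on what rank 0 produced
    assert os.path.exists(marker)
dist.barrier()
dist.destroy_process_group()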
@@ -366,23 +365,17 @@ def train(
    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
    if cfg.fsdp:
-        with model.summon_full_params():
-            model.save_pretrained(
-                cfg.output_dir,
-                is_main_process=trainer.accelerator.is_main_process,
-                save_function=trainer.accelerator.save,
-                state_dict=trainer.accelerator.get_state_dict(model),
-            )
+        trainer.save_model(cfg.output_dir)
    elif cfg.local_rank == 0:
        if cfg.flash_optimum:
            model = BetterTransformer.reverse(model)
        model.save_pretrained(cfg.output_dir)

    # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

-    train_dataset.cleanup_cache_files()
-    if eval_dataset:
-        eval_dataset.cleanup_cache_files()
+    trainer.accelerator.wait_for_everyone()
+    if trainer.accelerator.is_main_process:
+        train_dataset.cleanup_cache_files()
+        if eval_dataset:
+            eval_dataset.cleanup_cache_files()

if __name__ == "__main__":
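
In the save/cleanup hunk above, all ranks are synchronized with trainer.accelerator.wait_for_everyone() before only the main process deletes the shared cache files. Assuming the new helpers fall back to single-process defaults when no process group exists (matching the sketch after the import hunk, not something this diff shows), a plain single-GPU run takes the same code path:

# Hypothetical single-process sanity check; assumes the fallback behavior
# sketched earlier (treat the only process as rank 0, make barrier() a no-op).
from axolotl.utils.distributed import barrier, is_main_process

assert is_main_process()  # no process group -> treated as the main process
barrier()                 # expected to be a no-op outside distributed runs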