use custom distributed checks
This commit is contained in:
@@ -14,13 +14,13 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from accelerate import Accelerator
|
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
@@ -173,7 +173,6 @@ def train(
|
|||||||
prepare_ds_only: bool = False,
|
prepare_ds_only: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
accelerator = Accelerator()
|
|
||||||
if Path(config).is_dir():
|
if Path(config).is_dir():
|
||||||
config = choose_config(config)
|
config = choose_config(config)
|
||||||
|
|
||||||
@@ -243,17 +242,17 @@ def train(
|
|||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
if accelerator.is_local_main_process:
|
if is_main_process():
|
||||||
# process on rank 0 first so it gets cached so other ranks load from cache
|
# process on rank 0 first so it gets cached so other ranks load from cache
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
barrier()
|
||||||
if not accelerator.is_local_main_process:
|
if not is_main_process():
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
barrier()
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
@@ -366,23 +365,17 @@ def train(
|
|||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
with model.summon_full_params():
|
trainer.save_model(cfg.output_dir)
|
||||||
model.save_pretrained(
|
|
||||||
cfg.output_dir,
|
|
||||||
is_main_process=trainer.accelerator.is_main_process,
|
|
||||||
save_function=trainer.accelerator.save,
|
|
||||||
state_dict=trainer.accelerator.get_state_dict(model),
|
|
||||||
)
|
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
trainer.accelerator.wait_for_everyone()
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
train_dataset.cleanup_cache_files()
|
train_dataset.cleanup_cache_files()
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset.cleanup_cache_files()
|
eval_dataset.cleanup_cache_files()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from pathlib import Path
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
|
||||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@@ -37,9 +36,9 @@ from axolotl.prompters import (
|
|||||||
ShareGPTPrompter,
|
ShareGPTPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
accelerator = Accelerator()
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
@@ -112,16 +111,14 @@ def load_tokenized_prepared_datasets(
|
|||||||
local_path = Path(d.path)
|
local_path = Path(d.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
try:
|
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||||
ds = load_from_disk(d.path)
|
ds = load_dataset(
|
||||||
except FileNotFoundError:
|
d.path,
|
||||||
ds = load_dataset(
|
name=d.name,
|
||||||
d.path,
|
data_files=d.data_files,
|
||||||
name=d.name,
|
streaming=False,
|
||||||
data_files=d.data_files,
|
split=None,
|
||||||
streaming=False,
|
)
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json",
|
"json",
|
||||||
@@ -445,7 +442,7 @@ def load_prepare_datasets(
|
|||||||
to_hash_test.encode(), usedforsecurity=False
|
to_hash_test.encode(), usedforsecurity=False
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if accelerator.is_local_main_process:
|
if is_main_process():
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=cfg.val_set_size,
|
test_size=cfg.val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
@@ -453,8 +450,8 @@ def load_prepare_datasets(
|
|||||||
train_new_fingerprint=train_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
test_new_fingerprint=test_fingerprint,
|
test_new_fingerprint=test_fingerprint,
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
barrier()
|
||||||
if not accelerator.is_local_main_process:
|
if not is_main_process():
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=cfg.val_set_size,
|
test_size=cfg.val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
@@ -462,7 +459,7 @@ def load_prepare_datasets(
|
|||||||
train_new_fingerprint=train_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
test_new_fingerprint=test_fingerprint,
|
test_new_fingerprint=test_fingerprint,
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
barrier()
|
||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
|
|||||||
30
src/axolotl/utils/distributed.py
Normal file
30
src/axolotl/utils/distributed.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
utility helpers for distributed checks
|
||||||
|
"""
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed():
|
||||||
|
"""
|
||||||
|
Check if distributed training is initialized.
|
||||||
|
"""
|
||||||
|
return dist.is_available() and dist.is_initialized()
|
||||||
|
|
||||||
|
|
||||||
|
def barrier():
|
||||||
|
"""
|
||||||
|
Acts as a barrier to wait for all processes. This ensures that all processes
|
||||||
|
reach the barrier before proceeding further.
|
||||||
|
"""
|
||||||
|
if is_distributed():
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
"""
|
||||||
|
Check if the current process is the main process.
|
||||||
|
If not in distributed mode, always return True.
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return True
|
||||||
|
return dist.get_rank() == 0
|
||||||
Reference in New Issue
Block a user