diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py
index f3be9c2f4..216ff0110 100644
--- a/src/axolotl/evaluate.py
+++ b/src/axolotl/evaluate.py
@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
     del model
     del tokenizer
 
+    cleanup_distributed()
+
     return all_metrics
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index a6040ebaa..56e378c7f 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -27,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
@@ -157,6 +158,8 @@ def setup_signal_handler(
                 _model.save_pretrained(
                     cfg.output_dir, safe_serialization=safe_serialization
                 )
+
+            cleanup_distributed()
             sys.exit(0)
 
         _model_weakref = weakref.ref(model)
@@ -478,7 +481,7 @@ def train(
     Returns:
-        Tuple of (model, tokenizer) after training
+        Tuple of (model, tokenizer, trainer) after training
     """
-    # Setup model, tokenizer, (causal or RLHF) trainer etc.
+    # Setup model, tokenizer, (causal or RLHF) trainer, etc.
     (
         trainer,
         model,
@@ -487,34 +490,25 @@ def train(
         processor,
     ) = setup_model_and_trainer(cfg, dataset_meta)
 
-    # Determine if we need to resume from a checkpoint
-    resume_from_checkpoint = determine_resume_checkpoint(cfg)
-
-    # Configuration for saving
-    safe_serialization = cfg.save_safetensors is True
-
     # Handle untrained tokens if configured
+    safe_serialization = cfg.save_safetensors is True
     train_dataset = dataset_meta.train_dataset
     handle_untrained_tokens_fix(
         cfg, model, tokenizer, train_dataset, safe_serialization
     )
 
-    # Save initial configs
+    # Additional setup
     save_initial_configs(cfg, tokenizer, model, peft_config, processor)
-
-    # Set up signal handler for graceful termination
     setup_signal_handler(cfg, model, safe_serialization)
-
-    # Set up badges and config info for model card
     setup_model_card(cfg)
 
     # Execute the training
+    resume_from_checkpoint = determine_resume_checkpoint(cfg)
     execute_training(cfg, trainer, resume_from_checkpoint)
 
-    # Save the trained model
+    # Save the trained model and cleanup
     save_trained_model(cfg, trainer, model, safe_serialization)
-
-    # Create model card
     create_model_card(cfg, trainer)
+    cleanup_distributed()
 
     return model, tokenizer, trainer
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 7c671abfe..acd2566d0 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -71,8 +71,8 @@ def barrier():
 
 def is_main_process():
     """
-    Check if the current process is the main process.
-    If not in distributed mode, always return True.
+    Check if the current process is the main process. If not in distributed mode,
+    always return `True`.
     """
     if not is_distributed():
         return True
@@ -87,6 +87,24 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
+def cleanup_distributed():
+    """
+    Destroy process group if torch distributed is initialized. Called in training early
+    termination or when training successfully completes.
+
+    Each step is skipped when it does not apply: the CUDA synchronization when CUDA is
+    unavailable, and the teardown when no process group is initialized.
+    """
+    # Ensure that all operations are completed before destroying the process group.
+    # Guarded: `torch.cuda.synchronize()` raises on hosts without CUDA.
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    # Destroy the process group
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+
 @contextmanager
 def zero_only():
     """