Compare commits
2 Commits
nd_paralle
...
32ce167404
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32ce167404 | ||
|
|
1c4cc639f5 |
@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
|
|||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
|
|||||||
del model
|
del model
|
||||||
del tokenizer
|
del tokenizer
|
||||||
|
|
||||||
|
cleanup_distributed()
|
||||||
|
|
||||||
return all_metrics
|
return all_metrics
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
|||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
@@ -157,6 +158,8 @@ def setup_signal_handler(
|
|||||||
_model.save_pretrained(
|
_model.save_pretrained(
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cleanup_distributed()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
_model_weakref = weakref.ref(model)
|
_model_weakref = weakref.ref(model)
|
||||||
@@ -478,7 +481,7 @@ def train(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (model, tokenizer) after training
|
Tuple of (model, tokenizer) after training
|
||||||
"""
|
"""
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
model,
|
model,
|
||||||
@@ -487,34 +490,25 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
# Determine if we need to resume from a checkpoint
|
|
||||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
|
||||||
|
|
||||||
# Configuration for saving
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
|
|
||||||
# Handle untrained tokens if configured
|
# Handle untrained tokens if configured
|
||||||
|
safe_serialization = cfg.save_safetensors is True
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
handle_untrained_tokens_fix(
|
handle_untrained_tokens_fix(
|
||||||
cfg, model, tokenizer, train_dataset, safe_serialization
|
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save initial configs
|
# Additional setup
|
||||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||||
|
|
||||||
# Set up signal handler for graceful termination
|
|
||||||
setup_signal_handler(cfg, model, safe_serialization)
|
setup_signal_handler(cfg, model, safe_serialization)
|
||||||
|
|
||||||
# Set up badges and config info for model card
|
|
||||||
setup_model_card(cfg)
|
setup_model_card(cfg)
|
||||||
|
|
||||||
# Execute the training
|
# Execute the training
|
||||||
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
|
|
||||||
# Save the trained model
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
|
|
||||||
# Create model card
|
|
||||||
create_model_card(cfg, trainer)
|
create_model_card(cfg, trainer)
|
||||||
|
cleanup_distributed()
|
||||||
|
|
||||||
return model, tokenizer, trainer
|
return model, tokenizer, trainer
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ def barrier():
|
|||||||
|
|
||||||
def is_main_process():
|
def is_main_process():
|
||||||
"""
|
"""
|
||||||
Check if the current process is the main process.
|
Check if the current process is the main process. If not in distributed mode,
|
||||||
If not in distributed mode, always return True.
|
always return `True`.
|
||||||
"""
|
"""
|
||||||
if not is_distributed():
|
if not is_distributed():
|
||||||
return True
|
return True
|
||||||
@@ -87,6 +87,18 @@ def get_world_size():
|
|||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_distributed():
|
||||||
|
"""
|
||||||
|
Destroy process group if torch distributed is initialized. Called in training early
|
||||||
|
termination or when training successfully completes.
|
||||||
|
"""
|
||||||
|
# Ensure that all operations are completed before destroying the process group
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# Destroy the process group
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_only():
|
def zero_only():
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user