From 2f0b4626b9930d2b74914140a3f6aac6e4cfbbb2 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 19 Mar 2025 17:35:09 +0000
Subject: [PATCH] review comments, docstrings

---
 docs/config.qmd                            |  2 +-
 .../monkeypatch/attention/ring_attn.py     | 24 ++++++++++--
 src/axolotl/train.py                       | 38 +++++--------------
 src/axolotl/utils/models.py                |  2 +-
 tests/test_exact_deduplication.py          |  4 +-
 5 files changed, 35 insertions(+), 35 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index a68afde04..9946b5865 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -34,7 +34,7 @@ resize_token_embeddings_to_32x:
 shrink_embeddings:
 # Whether to load the model with randomly initialized weights. Useful for
 # pre-training a model from scratch or debugging purposes.
-random_init:
+random_init_weights:
 
 # (Internal use only)
 # Used to identify which model this one is based on
diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py
index fe333ad32..2cde5b98d 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn.py
@@ -1,6 +1,10 @@
-"""Ring attention group registration and flash attention patching."""
+"""
+Ring attention group registration and flash attention patching.
 
-from typing import Any
+Makes use of the `ring-flash-attn` package (https://github.com/zhuzilin/ring-flash-attention),
+specifically its `hf_adapter.substitute_hf_flash_attn` function, to patch in its
+sequence-parallel version of Flash Attention 2.
+"""
 
 import torch.distributed as dist
 from accelerate.logging import get_logger
@@ -14,11 +18,23 @@ LOG = get_logger(__name__)
 
 RING_ATTN_GROUP = None
 
 
-def get_ring_attn_group() -> Any:
+def get_ring_attn_group() -> dist.ProcessGroup | None:
+    """
+    Getter for the ring attention group on this rank.
+
+    Returns:
+        The process group for ring attention on this rank, or `None` if unset.
+    """
     return RING_ATTN_GROUP
 
 
-def set_ring_attn_group(ring_attn_group: Any):
+def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
+    """
+    Setter for the ring attention group on this rank.
+
+    Args:
+        ring_attn_group: Process group for ring attention.
+    """
     global RING_ATTN_GROUP  # pylint: disable=global-statement
     RING_ATTN_GROUP = ring_attn_group
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 4e6054df1..9ccd2ca0c 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,6 +1,5 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
 
-import contextlib
 import importlib
 import inspect
 import os
@@ -166,35 +165,11 @@ def setup_signal_handler(
     )
 
 
-def train_context_manager(
-    flash_optimum: bool = False,
-) -> contextlib.AbstractContextManager:
-    """
-    Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
-
-    Args:
-        flash_optimum: Whether to enable efficient backends for SDP attention.
-
-    Returns:
-        Context manager for temporarily enabling efficient backends for SDP attention
-        if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
-    """
-    if flash_optimum:
-        return torch.backends.cuda.sdp_kernel(
-            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
-            enable_flash=True,
-            enable_math=True,
-            enable_mem_efficient=True,
-        )
-
-    return contextlib.nullcontext()
-
-
 def execute_training(
     cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
 ):
     """
-    Execute the training process with appropriate backend configurations.
+    Execute the training process with appropriate SDP kernel configurations.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
@@ -202,8 +177,15 @@ def execute_training(
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
     LOG.info("Starting trainer...")
-    context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
-    with context_manager:
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
+            enable_flash=True,
+            enable_math=True,
+            enable_mem_efficient=True,
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index afb3e37a0..83f70a022 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -922,7 +922,7 @@ class ModelLoader:
             self.model_config.text_config = self.text_model_config
 
         # Load model with random initialization if specified
-        if self.cfg.random_init:
+        if self.cfg.random_init_weights:
             # AutoModel classes support the from_config method
             if self.auto_model_loader in [
                 AutoModelForCausalLM,
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
index 128d2d05c..d32eb3953 100644
--- a/tests/test_exact_deduplication.py
+++ b/tests/test_exact_deduplication.py
@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
 from datasets import Dataset
 from transformers import AutoTokenizer
 
+from axolotl.utils.config import normalize_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
         self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
         self.cfg_1 = DictDefault(
             {
+                "base_model": "huggyllama/llama-7b",
                 "tokenizer_config": "huggyllama/llama-7b",
                 "sequence_len": 1024,
                 "dataset_exact_deduplication": True,
@@ -280,9 +282,9 @@ class TestDeduplicateNonRL(unittest.TestCase):
                 "batch_size": 10,
                 "micro_batch_size": 10,
                 "num_epochs": 1,
-                "sequence_parallel_degree": 1,
             }
         )
+        normalize_config(self.cfg_1)
 
     def test_prepare_dataset_with_deduplication_train(self):
         """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
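
Note on the `ring_attn.py` changes: a minimal sketch of how the getter/setter pair
is typically used when carving ranks into sequence-parallel groups. The
`register_ring_attn_groups` helper and its `sequence_parallel_degree` argument are
illustrative assumptions, not part of this patch; only `get_ring_attn_group` and
`set_ring_attn_group` come from `ring_attn.py`.

    import torch.distributed as dist

    from axolotl.monkeypatch.attention.ring_attn import (
        get_ring_attn_group,
        set_ring_attn_group,
    )

    def register_ring_attn_groups(sequence_parallel_degree: int) -> None:
        # Hypothetical helper: partition the world into contiguous blocks of
        # `sequence_parallel_degree` ranks and remember this rank's group.
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for start in range(0, world_size, sequence_parallel_degree):
            ranks = list(range(start, start + sequence_parallel_degree))
            # `dist.new_group` is collective: every rank must call it for every
            # group, including groups it does not belong to.
            group = dist.new_group(ranks)
            if rank in ranks:
                set_ring_attn_group(group)
        assert get_ring_attn_group() is not None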
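Note on the `random_init` -> `random_init_weights` rename: inside `ModelLoader`,
the flag routes model construction through `from_config`, which builds the
architecture described by a config with randomly initialized weights instead of
downloading checkpoint shards. A standalone sketch of that Hugging Face code path
(the model name is just an example, not the loader's actual code):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Instantiate the architecture only; weights are randomly initialized and
    # no pretrained checkpoint is fetched.
    config = AutoConfig.from_pretrained("huggyllama/llama-7b")
    model = AutoModelForCausalLM.from_config(config)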