review comments, docstrings

This commit is contained in:
Dan Saunders
2025-03-19 17:35:09 +00:00
parent a26985c53c
commit 2f0b4626b9
5 changed files with 35 additions and 35 deletions

View File

@@ -34,7 +34,7 @@ resize_token_embeddings_to_32x:
shrink_embeddings: shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for # Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes. # pre-training a model from scratch or debugging purposes.
random_init: random_init_weights:
# (Internal use only) # (Internal use only)
# Used to identify which the model is based on # Used to identify which the model is based on

View File

@@ -1,6 +1,10 @@
"""Ring attention group registration and flash attention patching.""" """
Ring attention group registration and flash attention patching.
from typing import Any Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -14,11 +18,23 @@ LOG = get_logger(__name__)
RING_ATTN_GROUP = None RING_ATTN_GROUP = None
def get_ring_attn_group() -> Any: def get_ring_attn_group() -> dist.ProcessGroup:
"""
Getter for ring attention group on this rank.
Returns:
The process group for ring attention for this rank.
"""
return RING_ATTN_GROUP return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: Any): def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
"""
Setter for ring attention group on this rank.
Args:
Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group RING_ATTN_GROUP = ring_attn_group

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import contextlib
import importlib import importlib
import inspect import inspect
import os import os
@@ -166,35 +165,11 @@ def setup_signal_handler(
) )
def train_context_manager(
flash_optimum: bool = False,
) -> contextlib.AbstractContextManager:
"""
Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
Args:
flash_optimum: Whether to enable efficient backends for SDP attention.
Returns:
Context manager for temporarily enabling efficient backends for SDP attention
if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
"""
if flash_optimum:
return torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
)
return contextlib.nullcontext()
def execute_training( def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
): ):
""" """
Execute the training process with appropriate backend configurations. Execute the training process with appropriate SDP kernel configurations.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -202,8 +177,15 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable. resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
""" """
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
context_manager = train_context_manager(flash_optimum=cfg.flash_optimum) if cfg.flash_optimum:
with context_manager: with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -922,7 +922,7 @@ class ModelLoader:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
# Load model with random initialization if specified # Load model with random initialization if specified
if self.cfg.random_init: if self.cfg.random_init_weights:
# AutoModel classes support the from_config method # AutoModel classes support the from_config method
if self.auto_model_loader in [ if self.auto_model_loader in [
AutoModelForCausalLM, AutoModelForCausalLM,

View File

@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault( self.cfg_1 = DictDefault(
{ {
"base_model": "huggyllama/llama-7b",
"tokenizer_config": "huggyllama/llama-7b", "tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024, "sequence_len": 1024,
"dataset_exact_deduplication": True, "dataset_exact_deduplication": True,
@@ -280,9 +282,9 @@ class TestDeduplicateNonRL(unittest.TestCase):
"batch_size": 10, "batch_size": 10,
"micro_batch_size": 10, "micro_batch_size": 10,
"num_epochs": 1, "num_epochs": 1,
"sequence_parallel_degree": 1,
} }
) )
normalize_config(self.cfg_1)
def test_prepare_dataset_with_deduplication_train(self): def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication.""" """Verify that prepare_dataset function processes the dataset correctly with deduplication."""