review comments, docstrings

Dan Saunders
2025-03-19 17:35:09 +00:00
parent a26985c53c
commit 2f0b4626b9
5 changed files with 35 additions and 35 deletions

View File

@@ -34,7 +34,7 @@ resize_token_embeddings_to_32x:
 shrink_embeddings:
 # Whether to load the model with randomly initialized weights. Useful for
 # pre-training a model from scratch or debugging purposes.
-random_init:
+random_init_weights:
 # (Internal use only)
 # Used to identify which model this model is based on

View File

@@ -1,6 +1,10 @@
"""Ring attention group registration and flash attention patching."""
"""
Ring attention group registration and flash attention patching.
from typing import Any
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch.distributed as dist
from accelerate.logging import get_logger
@@ -14,11 +18,23 @@ LOG = get_logger(__name__)
 RING_ATTN_GROUP = None
 
 
-def get_ring_attn_group() -> Any:
+def get_ring_attn_group() -> dist.ProcessGroup:
+    """
+    Getter for the ring attention group on this rank.
+
+    Returns:
+        The process group for ring attention for this rank.
+    """
     return RING_ATTN_GROUP
 
 
-def set_ring_attn_group(ring_attn_group: Any):
+def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
+    """
+    Setter for the ring attention group on this rank.
+
+    Args:
+        ring_attn_group: Process group for ring attention.
+    """
     global RING_ATTN_GROUP  # pylint: disable=global-statement
     RING_ATTN_GROUP = ring_attn_group
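
For readers new to this module, below is a minimal sketch (not the axolotl implementation) of how the getter/setter above might be used to register a sequence-parallel process group per rank before patching flash attention. The `register_ring_attn` helper and its `sequence_parallel_degree` argument are illustrative, and the exact `ring-flash-attn` import path and call signature should be verified against that package.

import torch.distributed as dist


def register_ring_attn(sequence_parallel_degree: int) -> None:
    """Create one process group per contiguous block of `sequence_parallel_degree` ranks."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    for start_rank in range(0, world_size, sequence_parallel_degree):
        group_ranks = list(range(start_rank, start_rank + sequence_parallel_degree))
        # new_group must be called collectively by all ranks, even those not in group_ranks
        group = dist.new_group(ranks=group_ranks)
        if rank in group_ranks:
            set_ring_attn_group(group)  # setter defined in the module above

    # The flash attention patch would then be applied with the registered group,
    # e.g. (verify the import path and signature against ring-flash-attn):
    # substitute_hf_flash_attn(process_group=get_ring_attn_group(), heads_k_stride=1)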

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import contextlib
import importlib
import inspect
import os
@@ -166,35 +165,11 @@ def setup_signal_handler(
 )
 
 
-def train_context_manager(
-    flash_optimum: bool = False,
-) -> contextlib.AbstractContextManager:
-    """
-    Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
-
-    Args:
-        flash_optimum: Whether to enable efficient backends for SDP attention.
-
-    Returns:
-        Context manager for temporarily enabling efficient backends for SDP attention
-        if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
-    """
-    if flash_optimum:
-        return torch.backends.cuda.sdp_kernel(
-            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
-            enable_flash=True,
-            enable_math=True,
-            enable_mem_efficient=True,
-        )
-
-    return contextlib.nullcontext()
-
-
 def execute_training(
     cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
 ):
     """
-    Execute the training process with appropriate backend configurations.
+    Execute the training process with appropriate SDP kernel configurations.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
@@ -202,8 +177,15 @@ def execute_training(
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
     LOG.info("Starting trainer...")
-    context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
 
-    with context_manager:
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
+            enable_flash=True,
+            enable_math=True,
+            enable_mem_efficient=True,
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
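
On the remaining TODO: a hedged sketch of how the kernel toggles might eventually be read from the YAML rather than hard-coded. The `sdp_kernel_kwargs` key is the hypothetical option named in the TODO, not an existing axolotl setting, and newer PyTorch releases deprecate `torch.backends.cuda.sdp_kernel` in favor of `torch.nn.attention.sdpa_kernel`.

# Hypothetical follow-up to the TODO: pull the toggles from cfg, keeping today's defaults.
sdp_kernel_kwargs = cfg.get("sdp_kernel_kwargs") or {
    "enable_flash": True,
    "enable_math": True,
    "enable_mem_efficient": True,
}
with torch.backends.cuda.sdp_kernel(**sdp_kernel_kwargs):
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)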

View File

@@ -922,7 +922,7 @@ class ModelLoader:
         self.model_config.text_config = self.text_model_config
 
         # Load model with random initialization if specified
-        if self.cfg.random_init:
+        if self.cfg.random_init_weights:
             # AutoModel classes support the from_config method
             if self.auto_model_loader in [
                 AutoModelForCausalLM,
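
For context, the renamed flag steers model construction through `from_config` (random weights) instead of `from_pretrained` (checkpoint weights). A standalone sketch of that distinction, with a placeholder model name and a local flag standing in for the axolotl config:

from transformers import AutoConfig, AutoModelForCausalLM

model_name = "huggyllama/llama-7b"  # placeholder checkpoint
random_init_weights = True          # would come from the YAML config

config = AutoConfig.from_pretrained(model_name)
if random_init_weights:
    # Build the architecture only; weights start randomly initialized.
    model = AutoModelForCausalLM.from_config(config)
else:
    # Load pretrained weights as usual.
    model = AutoModelForCausalLM.from_pretrained(model_name)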

View File

@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
 from datasets import Dataset
 from transformers import AutoTokenizer
 
+from axolotl.utils.config import normalize_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
         self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
         self.cfg_1 = DictDefault(
             {
+                "base_model": "huggyllama/llama-7b",
                 "tokenizer_config": "huggyllama/llama-7b",
                 "sequence_len": 1024,
                 "dataset_exact_deduplication": True,
@@ -280,9 +282,9 @@ class TestDeduplicateNonRL(unittest.TestCase):
"batch_size": 10,
"micro_batch_size": 10,
"num_epochs": 1,
"sequence_parallel_degree": 1,
}
)
normalize_config(self.cfg_1)
def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""