review comments, docstrings
This commit is contained in:
@@ -34,7 +34,7 @@ resize_token_embeddings_to_32x:
|
||||
shrink_embeddings:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init:
|
||||
random_init_weights:
|
||||
|
||||
# (Internal use only)
|
||||
# Used to identify which the model is based on
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""Ring attention group registration and flash attention patching."""
|
||||
"""
|
||||
Ring attention group registration and flash attention patching.
|
||||
|
||||
from typing import Any
|
||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
@@ -14,11 +18,23 @@ LOG = get_logger(__name__)
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
|
||||
def get_ring_attn_group() -> Any:
|
||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||
"""
|
||||
Getter for ring attention group on this rank.
|
||||
|
||||
Returns:
|
||||
The process group for ring attention for this rank.
|
||||
"""
|
||||
return RING_ATTN_GROUP
|
||||
|
||||
|
||||
def set_ring_attn_group(ring_attn_group: Any):
|
||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
||||
"""
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
Args:
|
||||
Process group for ring attention.
|
||||
"""
|
||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
@@ -166,35 +165,11 @@ def setup_signal_handler(
|
||||
)
|
||||
|
||||
|
||||
def train_context_manager(
|
||||
flash_optimum: bool = False,
|
||||
) -> contextlib.AbstractContextManager:
|
||||
"""
|
||||
Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
|
||||
|
||||
Args:
|
||||
flash_optimum: Whether to enable efficient backends for SDP attention.
|
||||
|
||||
Returns:
|
||||
Context manager for temporarily enabling efficient backends for SDP attention
|
||||
if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
|
||||
"""
|
||||
if flash_optimum:
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
)
|
||||
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def execute_training(
|
||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||
):
|
||||
"""
|
||||
Execute the training process with appropriate backend configurations.
|
||||
Execute the training process with appropriate SDP kernel configurations.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -202,8 +177,15 @@ def execute_training(
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
LOG.info("Starting trainer...")
|
||||
context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
|
||||
with context_manager:
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
|
||||
|
||||
@@ -922,7 +922,7 @@ class ModelLoader:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
|
||||
# Load model with random initialization if specified
|
||||
if self.cfg.random_init:
|
||||
if self.cfg.random_init_weights:
|
||||
# AutoModel classes support the from_config method
|
||||
if self.auto_model_loader in [
|
||||
AutoModelForCausalLM,
|
||||
|
||||
@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
self.cfg_1 = DictDefault(
|
||||
{
|
||||
"base_model": "huggyllama/llama-7b",
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"dataset_exact_deduplication": True,
|
||||
@@ -280,9 +282,9 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
"batch_size": 10,
|
||||
"micro_batch_size": 10,
|
||||
"num_epochs": 1,
|
||||
"sequence_parallel_degree": 1,
|
||||
}
|
||||
)
|
||||
normalize_config(self.cfg_1)
|
||||
|
||||
def test_prepare_dataset_with_deduplication_train(self):
|
||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||
|
||||
Reference in New Issue
Block a user