review comments, docstrings
This commit is contained in:
@@ -34,7 +34,7 @@ resize_token_embeddings_to_32x:
|
|||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
# Whether to load the model with randomly initialized weights. Useful for
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
# pre-training a model from scratch or debugging purposes.
|
# pre-training a model from scratch or debugging purposes.
|
||||||
random_init:
|
random_init_weights:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify what the model is based on
|
# Used to identify what the model is based on
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
"""Ring attention group registration and flash attention patching."""
|
"""
|
||||||
|
Ring attention group registration and flash attention patching.
|
||||||
|
|
||||||
from typing import Any
|
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||||
|
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||||
|
their sequence parallel version of Flash Attention 2.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -14,11 +18,23 @@ LOG = get_logger(__name__)
|
|||||||
RING_ATTN_GROUP = None
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
def get_ring_attn_group() -> Any:
|
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||||
|
"""
|
||||||
|
Getter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The process group for ring attention for this rank.
|
||||||
|
"""
|
||||||
return RING_ATTN_GROUP
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
def set_ring_attn_group(ring_attn_group: Any):
|
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
||||||
|
"""
|
||||||
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ring_attn_group: Process group for ring attention.
|
||||||
|
"""
|
||||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
RING_ATTN_GROUP = ring_attn_group
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
@@ -166,35 +165,11 @@ def setup_signal_handler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_context_manager(
|
|
||||||
flash_optimum: bool = False,
|
|
||||||
) -> contextlib.AbstractContextManager:
|
|
||||||
"""
|
|
||||||
Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
flash_optimum: Whether to enable efficient backends for SDP attention.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Context manager for temporarily enabling efficient backends for SDP attention
|
|
||||||
if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
|
|
||||||
"""
|
|
||||||
if flash_optimum:
|
|
||||||
return torch.backends.cuda.sdp_kernel(
|
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
|
||||||
enable_flash=True,
|
|
||||||
enable_math=True,
|
|
||||||
enable_mem_efficient=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
|
|
||||||
def execute_training(
|
def execute_training(
|
||||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute the training process with appropriate backend configurations.
|
Execute the training process with appropriate SDP kernel configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
@@ -202,8 +177,15 @@ def execute_training(
|
|||||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
"""
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
|
if cfg.flash_optimum:
|
||||||
with context_manager:
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||||
|
enable_flash=True,
|
||||||
|
enable_math=True,
|
||||||
|
enable_mem_efficient=True,
|
||||||
|
):
|
||||||
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -922,7 +922,7 @@ class ModelLoader:
|
|||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
|
|
||||||
# Load model with random initialization if specified
|
# Load model with random initialization if specified
|
||||||
if self.cfg.random_init:
|
if self.cfg.random_init_weights:
|
||||||
# AutoModel classes support the from_config method
|
# AutoModel classes support the from_config method
|
||||||
if self.auto_model_loader in [
|
if self.auto_model_loader in [
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||||
self.cfg_1 = DictDefault(
|
self.cfg_1 = DictDefault(
|
||||||
{
|
{
|
||||||
|
"base_model": "huggyllama/llama-7b",
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"dataset_exact_deduplication": True,
|
"dataset_exact_deduplication": True,
|
||||||
@@ -280,9 +282,9 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"batch_size": 10,
|
"batch_size": 10,
|
||||||
"micro_batch_size": 10,
|
"micro_batch_size": 10,
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"sequence_parallel_degree": 1,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
normalize_config(self.cfg_1)
|
||||||
|
|
||||||
def test_prepare_dataset_with_deduplication_train(self):
|
def test_prepare_dataset_with_deduplication_train(self):
|
||||||
"""Verify that the prepare_dataset function processes the dataset correctly with deduplication."""
|
"""Verify that the prepare_dataset function processes the dataset correctly with deduplication."""
|
||||||
|
|||||||
Reference in New Issue
Block a user