support for true batches with multipack (#1230)

* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
Wing Lian
2024-02-01 10:18:42 -05:00
committed by GitHub
parent c67fb71583
commit 00568c1539
24 changed files with 573 additions and 246 deletions

View File

@@ -11,7 +11,6 @@ import torch
import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -24,6 +23,11 @@ from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -124,7 +128,7 @@ def train(
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
@@ -149,7 +153,10 @@ def train(
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
@@ -195,7 +202,7 @@ def train(
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)