Phi2 rewrite (#1058)

* restore to current phi modeling code from phi-2

* enable gradient checkpointing

* don't cast everything to float32 all the time

* gradient checkpointing for phi2 ParallelBlock module too

* fix enabling flash attn for phi2

* add comment about import

* fix phi2 example

* fix model type check for tokenizer

* revert float32 -> bf16 casting changes

* support fused dense flash attn

* fix the repo for flash-attn

* add package name for subdir pkg

* fix the data collator when not using sample packing

* install packaging for pytests in ci

* also fix setup to not install flash attn fused dense subdir if not extras

* split out the fused-dense-lib in extra requires

* don't train w group_by_length for phi

* update integration test to use phi2

* set max steps and save steps for phi e2e tests

* try to workaround ssave issue in ci

* skip phi2 e2e test for now
This commit is contained in:
Wing Lian
2024-01-08 14:04:22 -05:00
committed by GitHub
parent 9ca358b671
commit 732851f105
7 changed files with 230 additions and 99 deletions

73
examples/phi/phi2-ft.yml Normal file
View File

@@ -0,0 +1,73 @@
base_model: microsoft/phi-2
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 2048
sample_packing: false # currently unsupported
pad_to_sequence_len:
adapter:
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embd
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: paged_adamw_8bit
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 1e-5
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"

View File

@@ -12,6 +12,7 @@ fire
PyYAML>=6.0
datasets>=2.15.0
flash-attn==2.3.3
fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib
sentencepiece
wandb
einops

View File

@@ -17,6 +17,7 @@ def parse_requirements():
_dependency_links.append(url)
elif (
"flash-attn" not in line
and "flash-attention" not in line
and "deepspeed" not in line
and line
and line[0] != "#"
@@ -51,6 +52,9 @@ setup(
"flash-attn": [
"flash-attn==2.3.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed",
],

View File

@@ -34,6 +34,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler
@@ -843,7 +844,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
return BatchSamplerDataCollatorForSeq2Seq(
if training_args.sample_packing:
return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)
return DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,

View File

@@ -9,27 +9,32 @@ from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
from .configuration_phi import PhiConfig
try:
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
from flash_attn.ops.fused_dense import FusedDense
except: # noqa: E722
except ImportError:
pad_input, unpad_input = None, None
FlashRotaryEmbedding = None
FlashSelfAttention, FlashCrossAttention = None, None
# this is in a seperate try/except block since sometimes fused_dense isn't available
# and it shouldn't completely disable flash attn when it isn't
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDense = None
@@ -224,7 +229,9 @@ class RotaryEmbedding(nn.Module):
# Initialize cached attributes since ONNX can't rely on dynamic initialization
self._update_cos_sin_cache(
max_position_embeddings, device=device, dtype=torch.float32
max_position_embeddings,
device=device,
dtype=torch.float32,
)
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
@@ -281,34 +288,32 @@ class RotaryEmbedding(nn.Module):
seqlen_offset: int = 0,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_start = seqlen_offset
seq_end = seq_start + qkv.shape[1]
if (
self._cos_cached.device != qkv.device
self._seq_len_cached < qkv.shape[1] + seqlen_offset
or self._cos_cached.device != qkv.device
or self._cos_cached.dtype != qkv.dtype
or (self.training and self._cos_cached.is_inference())
):
self._update_cos_sin_cache(
self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype
qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
)
if kv is None:
return _apply_rotary_emb_qkv(
qkv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
)
else:
q = _apply_rotary_emb(
qkv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
)
kv = _apply_rotary_emb_kv(
kv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
)
return q, kv
@@ -511,7 +516,7 @@ def _update_kv_cache(
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.key_value_memory_dict[layer_idx] = torch.empty(
inference_params.max_batch_size,
inference_params.max_seqlen,
2,
@@ -520,9 +525,6 @@ def _update_kv_cache(
dtype=kv.dtype,
device=kv.device,
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
@@ -530,8 +532,19 @@ def _update_kv_cache(
sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
# When the current sequence length is equal to or larger than the maximum sequence length,
# we need to concatenate the current `kv` with the cached `kv` to expand its length
if sequence_end >= inference_params.max_seqlen:
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
(inference_params.key_value_memory_dict[layer_idx], kv), dim=1
)
inference_params.key_value_memory_dict[layer_idx][
batch_start:batch_end, sequence_start:sequence_end, ...
] = kv
kv = inference_params.key_value_memory_dict[layer_idx][
batch_start:batch_end, :sequence_end, ...
]
return kv
@@ -624,13 +637,10 @@ class MHA(nn.Module):
self.layer_idx = layer_idx
self.return_residual = return_residual
self.checkpointing = checkpointing
self._gradient_checkpointing_func = None
def _forward_self_attn(
self,
x: torch.FloatTensor,
key_padding_mask: Optional[torch.BoolTensor],
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
) -> torch.FloatTensor:
qkv = self.Wqkv(x)
qkv = rearrange(
@@ -643,20 +653,21 @@ class MHA(nn.Module):
if self.flash_attn:
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
if (
key_padding_mask is not None
and cu_seqlens is None
and max_seqlen is None
):
cu_seqlens, max_seqlen = None, None
if key_padding_mask is not None:
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
qkv, indices, cu_seqlens, max_seqlen = unpad_input(
qkv, key_padding_mask
)
if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
if self.checkpointing and self.training:
attn_output = self._gradient_checkpointing_func(
self.inner_attn,
qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
use_reentrant=False,
)
else:
attn_output = self.inner_attn(
@@ -670,9 +681,12 @@ class MHA(nn.Module):
else attn_output
)
if self.checkpointing:
return torch.utils.checkpoint.checkpoint(
self.inner_attn, qkv, key_padding_mask=key_padding_mask
if self.checkpointing and self.training:
return self._gradient_checkpointing_func(
self.inner_attn,
qkv,
key_padding_mask=key_padding_mask,
use_reentrant=False,
)
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
@@ -725,8 +739,8 @@ class MHA(nn.Module):
q, key_padding_mask
)
if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(
if self.checkpointing and self.training:
attn_output = self._gradient_checkpointing_func(
self.inner_cross_attn,
q,
kv,
@@ -735,6 +749,7 @@ class MHA(nn.Module):
max_seqlen=max_seqlen_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=max_seqlen_k,
use_reentrant=False,
)
else:
attn_output = self.inner_cross_attn(
@@ -753,13 +768,14 @@ class MHA(nn.Module):
else attn_output
)
if self.checkpointing:
return torch.utils.checkpoint.checkpoint(
if self.checkpointing and self.training:
return self._gradient_checkpointing_func(
self.inner_cross_attn,
q,
kv,
key_padding_mask=key_padding_mask,
causal=causal,
use_reentrant=False,
)
return self.inner_cross_attn(
@@ -771,11 +787,8 @@ class MHA(nn.Module):
x: torch.FloatTensor,
past_key_values: Optional[InferenceParams] = None,
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
if attention_mask is not None:
attention_mask = attention_mask.bool()
else:
@@ -785,18 +798,12 @@ class MHA(nn.Module):
if self.n_head == self.n_head_kv:
if past_key_values is None:
# If `past_key_values` are not supplied, we run self-attention
attn_output = self._forward_self_attn(
x, attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
attn_output = self._forward_self_attn(x, attention_mask)
else:
# If `past_key_values` are supplied, it means that we might have cached values and
# could take advantage of cross-attention
attn_output = self._forward_cross_attn(
x,
past_key_values,
attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
x, past_key_values, attention_mask
)
# MQA / GQA
else:
@@ -830,6 +837,8 @@ class ParallelBlock(nn.Module):
self.mixer = MHA(config, layer_idx=block_idx)
self.mlp = MLP(config)
self.checkpointing = False
self._gradient_checkpointing_func = None
def forward(
self,
@@ -838,23 +847,52 @@ class ParallelBlock(nn.Module):
attention_mask: Optional[torch.BoolTensor] = None,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
def _forward(
mixer,
resid_dropout,
mlp,
ln,
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
past_key_values,
attention_mask,
):
residual = hidden_states
hidden_states = ln(hidden_states)
attn_outputs = mixer(
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
)
if isinstance(attn_outputs, tuple):
attn_outputs = attn_outputs[0]
attn_outputs = resid_dropout(attn_outputs)
feed_forward_hidden_states = resid_dropout(mlp(hidden_states))
return attn_outputs + feed_forward_hidden_states + residual
if self.training and self.checkpointing:
return self._gradient_checkpointing_func(
_forward,
self.mixer,
self.resid_dropout,
self.mlp,
self.ln,
hidden_states,
past_key_values,
attention_mask,
)
return _forward(
self.mixer,
self.resid_dropout,
self.mlp,
self.ln,
hidden_states,
past_key_values,
attention_mask,
)
if isinstance(attn_outputs, tuple):
attn_outputs = attn_outputs[0]
attn_outputs = self.resid_dropout(attn_outputs)
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
class CausalLMHead(nn.Module):
@@ -911,7 +949,7 @@ class PhiPreTrainedModel(PreTrainedModel):
config_class = PhiConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = False
supports_gradient_checkpointing = True
_no_split_modules = ["ParallelBlock"]
def __init__(self, *inputs, **kwargs) -> None:
@@ -931,6 +969,14 @@ class PhiPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint
):
for module in self.modules():
if hasattr(module, "checkpointing"):
module._gradient_checkpointing_func = gradient_checkpointing_func
module.checkpointing = enable
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
@@ -951,7 +997,7 @@ class PhiPreTrainedModel(PreTrainedModel):
)
else:
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
past_key_values.seqlen_offset = len(input_ids[0]) - 1
past_key_values.seqlen_offset = input_ids.shape[1] - 1
input_ids = input_ids[:, -1].unsqueeze(-1)
return {
@@ -988,8 +1034,6 @@ class PhiModel(PhiPreTrainedModel):
input_ids: torch.LongTensor,
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
attention_mask: Optional[torch.BoolTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
) -> torch.FloatTensor:
hidden_states = self.embd(input_ids)
@@ -998,8 +1042,6 @@ class PhiModel(PhiPreTrainedModel):
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
return hidden_states
@@ -1034,23 +1076,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
attention_mask: Optional[torch.BoolTensor] = None,
labels: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
cu_seqlens: Optional[torch.LongTensor] = None
max_seqlen: Optional[int] = None
if position_ids is not None:
batch_size, seq_length = input_ids.shape
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
hidden_states = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
input_ids, past_key_values=past_key_values, attention_mask=attention_mask
)
lm_logits = self.lm_head(hidden_states)

View File

@@ -55,6 +55,8 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
if not model_config_name and cfg.tokenizer_config:
model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True
try:
@@ -80,6 +82,7 @@ def load_model_config(cfg):
def load_tokenizer(cfg):
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -139,6 +142,7 @@ def load_tokenizer(cfg):
for k, val in cfg.special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and cfg.adapter
@@ -149,6 +153,7 @@ def load_tokenizer(cfg):
for x in ["embed_tokens", "lm_head"]
)
)
and (model_config.model_type in ["llama", "mistral", "mixtral"])
):
raise ValueError(
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
@@ -386,6 +391,10 @@ def load_model(
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
if model_config.model_type == "phi-msft":
model_config.flash_attn = True
model_config.flash_rotary = True
model_config.fused_dense = True
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -438,11 +447,12 @@ def load_model(
# device=cfg.device,
# )
# model.train() # sets to train instead of eval mode
elif model_type == "PhiForCausalLM":
elif model_type == "PhiForCausalLM" or model_config.model_type == "phi-msft":
from axolotl.models.phi import PhiForCausalLM
model = PhiForCausalLM.from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -21,17 +23,18 @@ os.environ["WANDB_DISABLED"] = "true"
class TestPhi(unittest.TestCase):
"""
Test case for Llama models using LoRA
Test case for Phi2 models
"""
@pytest.mark.skip(reason="fixme later")
@with_temp_dir
def test_ft(self, temp_dir):
def test_phi2_ft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model": "microsoft/phi-2",
"trust_remote_code": True,
"model_type": "PhiForCausalLM",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 512,
"sample_packing": False,
@@ -39,9 +42,6 @@ class TestPhi(unittest.TestCase):
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<|endoftext|>",
"bos_token": "<|endoftext|>",
"eos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
@@ -57,9 +57,14 @@ class TestPhi(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"bf16": True,
"flash_attention": True,
"max_steps": 10,
"save_steps": 10,
"eval_steps": 10,
"save_safetensors": True,
}
)
normalize_config(cfg)
@@ -69,12 +74,13 @@ class TestPhi(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
@pytest.mark.skip(reason="multipack no longer supported atm")
@with_temp_dir
def test_ft_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model": "microsoft/phi-2",
"trust_remote_code": True,
"model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer",