Multipack simplify for Mixtral (#1142)
This commit is contained in:
@@ -12,7 +12,7 @@ from abc import abstractmethod
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -37,6 +37,7 @@ from axolotl.utils.collators import (
|
|||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
@@ -896,14 +897,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if is_eval and training_args.eval_sample_packing:
|
if is_eval and training_args.eval_sample_packing:
|
||||||
use_batch_sampler_collator = True
|
use_batch_sampler_collator = True
|
||||||
|
|
||||||
|
collator: Type[
|
||||||
|
Union[
|
||||||
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
]
|
||||||
|
]
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
return BatchSamplerDataCollatorForSeq2Seq(
|
if self.cfg.model_config_type == "mixtral":
|
||||||
self.tokenizer,
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
return_tensors="pt",
|
else:
|
||||||
**kwargs,
|
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||||
)
|
else:
|
||||||
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
return DataCollatorForSeq2Seq(
|
return collator(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -3,20 +3,10 @@ Patches to support multipack for mixtral
|
|||||||
"""
|
"""
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
|
|
||||||
def replace_mixtral_attn_with_multipack_flash_attn():
|
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||||
from .modeling_mixtral import (
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
MixtralMultipackFlashAttention2,
|
get_unpad_data
|
||||||
mixtral_decoder_layer_forward,
|
|
||||||
mixtral_model_forward,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
|
||||||
mixtral_decoder_layer_forward
|
|
||||||
)
|
|
||||||
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
|
||||||
mixtral_model_forward
|
|
||||||
)
|
|
||||||
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
|
|
||||||
"flash_attention_2"
|
|
||||||
] = MixtralMultipackFlashAttention2
|
|
||||||
|
|||||||
@@ -1,383 +0,0 @@
|
|||||||
"""
|
|
||||||
Mixtral modeling for multipack
|
|
||||||
"""
|
|
||||||
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from einops import rearrange
|
|
||||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
|
||||||
from transformers import Cache, DynamicCache
|
|
||||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
|
||||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
|
||||||
from transformers.models.mixtral.modeling_mixtral import (
|
|
||||||
MixtralFlashAttention2,
|
|
||||||
apply_rotary_pos_emb,
|
|
||||||
repeat_kv,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
|
||||||
"""
|
|
||||||
Custom multipack implementation w flash attention 2
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._flash_attn_uses_top_left_mask = True
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
if "padding_mask" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
|
||||||
)
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
key_states = key_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
value_states = value_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
if self.layer_idx is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
|
||||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
|
||||||
"with a layer index."
|
|
||||||
)
|
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin, position_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
|
||||||
key_states, value_states = past_key_value.update(
|
|
||||||
key_states, value_states, self.layer_idx, cache_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
|
||||||
# special handling using sample packing
|
|
||||||
qkv = torch.stack(
|
|
||||||
[query_states, key_states, value_states], dim=2
|
|
||||||
) # [bsz, nh, 3, q_len, hd]
|
|
||||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_qkvpacked_func(
|
|
||||||
qkv,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
dropout_p=self.attention_dropout,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
|
||||||
attn_weights = None
|
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
def mixtral_decoder_layer_forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
output_router_logits: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
||||||
if "padding_mask" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
|
||||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
|
||||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
output_router_logits (`bool`, *optional*):
|
|
||||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
|
||||||
should not be returned during inference.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
||||||
(see `past_key_values`).
|
|
||||||
"""
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
|
|
||||||
# Self Attention
|
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (self_attn_weights,)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
outputs += (present_key_value,)
|
|
||||||
|
|
||||||
if output_router_logits:
|
|
||||||
outputs += (router_logits,)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def mixtral_model_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
output_router_logits: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, MoeModelOutputWithPast]:
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_router_logits = (
|
|
||||||
output_router_logits
|
|
||||||
if output_router_logits is not None
|
|
||||||
else self.config.output_router_logits
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# retrieve input_ids and inputs_embeds
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
|
||||||
)
|
|
||||||
if input_ids is not None:
|
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
|
||||||
)
|
|
||||||
|
|
||||||
past_key_values_length = 0
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
|
||||||
if use_legacy_cache:
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
|
||||||
|
|
||||||
cu_seqlens = None
|
|
||||||
max_seqlen = None
|
|
||||||
if position_ids is None:
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
position_ids = torch.arange(
|
|
||||||
past_key_values_length,
|
|
||||||
seq_length + past_key_values_length,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
|
||||||
else:
|
|
||||||
position_ids = position_ids.view(-1, seq_length).long()
|
|
||||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
|
||||||
cu_seqlens = cu_seqlens.squeeze()
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
if (
|
|
||||||
attention_mask is not None
|
|
||||||
and self._attn_implementation == "flash_attention_2"
|
|
||||||
and use_cache
|
|
||||||
):
|
|
||||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
|
||||||
if is_padding_right:
|
|
||||||
raise ValueError(
|
|
||||||
"You are attempting to perform batched generation with padding_side='right'"
|
|
||||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
|
||||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._attn_implementation == "flash_attention_2":
|
|
||||||
# 2d mask is passed through the layers
|
|
||||||
attention_mask = (
|
|
||||||
attention_mask
|
|
||||||
if (attention_mask is not None and 0 in attention_mask)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 4d mask is passed through the layers
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
sliding_window=self.config.sliding_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
LOG.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
# decoder layers
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
all_self_attns = () if output_attentions else None
|
|
||||||
all_router_logits = () if output_router_logits else None
|
|
||||||
next_decoder_cache = None
|
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
|
||||||
decoder_layer.__call__,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
position_ids,
|
|
||||||
past_key_values,
|
|
||||||
output_attentions,
|
|
||||||
output_router_logits,
|
|
||||||
use_cache,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_router_logits=output_router_logits,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attns += (layer_outputs[1],)
|
|
||||||
|
|
||||||
if output_router_logits:
|
|
||||||
all_router_logits += (layer_outputs[-1],)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
next_cache = None
|
|
||||||
if use_cache:
|
|
||||||
next_cache = (
|
|
||||||
next_decoder_cache.to_legacy_cache()
|
|
||||||
if use_legacy_cache
|
|
||||||
else next_decoder_cache
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(
|
|
||||||
v
|
|
||||||
for v in [
|
|
||||||
hidden_states,
|
|
||||||
next_cache,
|
|
||||||
all_hidden_states,
|
|
||||||
all_self_attns,
|
|
||||||
all_router_logits,
|
|
||||||
]
|
|
||||||
if v is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
return MoeModelOutputWithPast(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=next_cache,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attns,
|
|
||||||
router_logits=all_router_logits,
|
|
||||||
)
|
|
||||||
@@ -2,6 +2,40 @@
|
|||||||
Shared utils for the monkeypatches
|
Shared utils for the monkeypatches
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
|
max_num = int(torch.max(attention_mask).item())
|
||||||
|
batch_size, _ = attention_mask.shape
|
||||||
|
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
||||||
|
|
||||||
|
for i in range(1, max_num + 1):
|
||||||
|
mask = attention_mask == i
|
||||||
|
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
result = counts.flatten()
|
||||||
|
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
||||||
|
return result[nonzero_indices]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def get_unpad_data(attention_mask: torch.Tensor):
|
||||||
|
device = attention_mask.device
|
||||||
|
seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten()).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = (
|
||||||
|
F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
.to(device=device)
|
||||||
|
.detach()
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
indices,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen_in_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens(attn_mask):
|
def get_cu_seqlens(attn_mask):
|
||||||
|
|||||||
@@ -152,6 +152,33 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
"""
|
||||||
|
Collator for multipack specific to the using the BatchSampler
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
chunked_data = {}
|
||||||
|
for feature in features[0].keys():
|
||||||
|
if feature == "length":
|
||||||
|
continue
|
||||||
|
if feature == "attention_mask":
|
||||||
|
arrays = [
|
||||||
|
(i + 1) * np.array(item[feature])
|
||||||
|
for i, item in enumerate(features)
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
|
else:
|
||||||
|
arrays = [
|
||||||
|
np.array(item[feature]) for item in features if feature in item
|
||||||
|
]
|
||||||
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
|
features = [chunked_data]
|
||||||
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MambaDataCollator:
|
class MambaDataCollator:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
"""Module for working with config dicts"""
|
"""Module for working with config dicts"""
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -135,7 +137,7 @@ def normalize_config(cfg):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
or cfg.is_mistral_derived_model
|
or cfg.is_mistral_derived_model
|
||||||
or "mistral" in cfg.base_model.lower()
|
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
||||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -484,6 +486,40 @@ def validate_config(cfg):
|
|||||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.unfrozen_parameters
|
||||||
|
and cfg.gradient_checkpointing_kwargs
|
||||||
|
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
||||||
|
):
|
||||||
|
# https://github.com/huggingface/transformers/issues/21381
|
||||||
|
raise ValueError(
|
||||||
|
"`use_reentrant` must be false when used with partially frozen model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||||
|
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||||
|
contents = file.read()
|
||||||
|
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||||
|
if (
|
||||||
|
deepspeed_cfg.zero_optimization
|
||||||
|
and deepspeed_cfg.zero_optimization.stage == 3
|
||||||
|
):
|
||||||
|
if not (
|
||||||
|
(
|
||||||
|
deepspeed_cfg.bf16
|
||||||
|
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
deepspeed_cfg.fp16
|
||||||
|
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -305,12 +305,16 @@ def load_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Modify mistral derived models
|
# Modify mistral derived models
|
||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if (
|
||||||
|
cfg.model_config_type == "mistral"
|
||||||
|
and cfg.flash_attention
|
||||||
|
and cfg.sample_packing
|
||||||
|
):
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
replace_mistral_attn_with_flash_attn,
|
replace_mistral_attn_with_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching with flash attention")
|
LOG.info("patching mistral with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -322,7 +326,7 @@ def load_model(
|
|||||||
replace_mixtral_attn_with_multipack_flash_attn,
|
replace_mixtral_attn_with_multipack_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching with flash attention")
|
LOG.info("patching mixtral with flash attention")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
||||||
or cfg.model_config_type == "mamba"
|
or cfg.model_config_type == "mamba"
|
||||||
):
|
):
|
||||||
|
LOG.info("dropping attention_mask column")
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -60,12 +58,9 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -101,23 +96,16 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
"axolotl.monkeypatch.mixtral.modeling_mixtral"
|
"MixtralFlashAttention2"
|
||||||
in model.model.layers[0].self_attn.__class__.__module__
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
"MixtralMultipackFlashAttention2"
|
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
in model.model.layers[0].self_attn.__class__.__name__
|
||||||
)
|
)
|
||||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||||
|
|||||||
@@ -52,11 +52,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"axolotl.monkeypatch.mixtral.modeling_mixtral"
|
"MixtralFlashAttention2"
|
||||||
in model.model.layers[0].self_attn.__class__.__module__
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
"MixtralMultipackFlashAttention2"
|
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
in model.model.layers[0].self_attn.__class__.__name__
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,12 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import (
|
||||||
|
get_cu_seqlens,
|
||||||
|
get_cu_seqlens_from_pos_ids,
|
||||||
|
get_max_seqlen_in_batch,
|
||||||
|
get_unpad_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestMonkeyPatchUtils(unittest.TestCase):
|
class TestMonkeyPatchUtils(unittest.TestCase):
|
||||||
@@ -25,6 +30,70 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
|||||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_get_max_seqlen_in_batch(self):
|
||||||
|
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||||
|
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
||||||
|
self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
|
||||||
|
|
||||||
|
def test_get_unpad_data(self):
|
||||||
|
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||||
|
target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
|
||||||
|
target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
|
||||||
|
target_max_seqlen_in_batch = 5
|
||||||
|
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
||||||
|
self.assertTrue(torch.allclose(target_indices, indices))
|
||||||
|
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
||||||
|
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
||||||
|
|
||||||
|
attn_mask = torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
|
||||||
|
[1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
target_indices = torch.tensor(
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
12,
|
||||||
|
13,
|
||||||
|
16,
|
||||||
|
17,
|
||||||
|
18,
|
||||||
|
19,
|
||||||
|
20,
|
||||||
|
21,
|
||||||
|
22,
|
||||||
|
23,
|
||||||
|
24,
|
||||||
|
25,
|
||||||
|
26,
|
||||||
|
27,
|
||||||
|
28,
|
||||||
|
29,
|
||||||
|
30,
|
||||||
|
31,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
target_cu_seqlen = torch.tensor(
|
||||||
|
[0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
|
||||||
|
)
|
||||||
|
target_max_seqlen_in_batch = 5
|
||||||
|
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
||||||
|
self.assertTrue(torch.allclose(target_indices, indices))
|
||||||
|
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
||||||
|
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user