Mixtral multipack (#928)
* mixtral multipack * use mixtral model * sample yml * calculate cu_seqlens properly * use updated flash ettention setting * attn var checks * force use of flash attention 2 for packing * lint * disable future fix for now * update support table
This commit is contained in:
@@ -8,6 +8,9 @@ ignore_missing_imports = True
|
|||||||
[mypy-axolotl.monkeypatch.*]
|
[mypy-axolotl.monkeypatch.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
[mypy-axolotl.models.mixtral.*]
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-axolotl.models.phi.*]
|
[mypy-axolotl.models.phi.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
28
README.md
28
README.md
@@ -65,19 +65,21 @@ Features:
|
|||||||
|
|
||||||
## Axolotl supports
|
## Axolotl supports
|
||||||
|
|
||||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||||
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
|||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
|
|
||||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||||
|
|
||||||
|
|||||||
78
examples/mistral/mixtral.yml
Normal file
78
examples/mistral/mixtral.yml
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
base_model: DiscoResearch/mixtral-7b-8expert
|
||||||
|
model_type: MixtralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
#lora_target_modules:
|
||||||
|
# - gate
|
||||||
|
# - q_proj
|
||||||
|
# - k_proj
|
||||||
|
# - v_proj
|
||||||
|
# - o_proj
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - w3
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps:
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed/zero2.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
6
src/axolotl/models/mixtral/__init__.py
Normal file
6
src/axolotl/models/mixtral/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""
|
||||||
|
Custom modeling code for mixtral
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .configuration_moe_mistral import MixtralConfig # noqa
|
||||||
|
from .modeling_moe_mistral import MixtralForCausalLM # noqa
|
||||||
154
src/axolotl/models/mixtral/configuration_moe_mistral.py
Normal file
154
src/axolotl/models/mixtral/configuration_moe_mistral.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Mistral model configuration"""
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
"mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
|
||||||
|
"mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
|
||||||
|
Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
|
||||||
|
|
||||||
|
[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||||
|
[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 32000):
|
||||||
|
Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`MistralModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
||||||
|
The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
|
||||||
|
allows sequence of up to 4096*32 tokens.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
The id of the padding token.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
The id of the "beginning-of-sequence" token.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
The id of the "end-of-sequence" token.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the model's input and output word embeddings should be tied.
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
sliding_window (`int`, *optional*, defaults to 4096):
|
||||||
|
Sliding window attention window size. If not specified, will default to `4096`.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import MistralModel, MistralConfig
|
||||||
|
|
||||||
|
>>> # Initializing a Mistral 7B style configuration
|
||||||
|
>>> configuration = MixtralConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the Mistral 7B style configuration
|
||||||
|
>>> model = MixtralModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "mistral"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
num_experts_per_token=2,
|
||||||
|
num_experts=8,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.num_experts_per_token = num_experts_per_token
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
1506
src/axolotl/models/mixtral/modeling_moe_mistral.py
Normal file
1506
src/axolotl/models/mixtral/modeling_moe_mistral.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -54,18 +54,25 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
try:
|
model_type = cfg.model_type
|
||||||
model_config = AutoConfig.from_pretrained(
|
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
if model_type == "MixtralForCausalLM":
|
||||||
)
|
from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
|
||||||
except ValueError as err:
|
|
||||||
if "mamba" in model_config_name:
|
model_config = MixtralConfig.from_pretrained(model_config_name)
|
||||||
return addict.Dict(
|
else:
|
||||||
{
|
try:
|
||||||
"model_type": "mamba",
|
model_config = AutoConfig.from_pretrained(
|
||||||
}
|
model_config_name, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
raise err
|
except ValueError as err:
|
||||||
|
if "mamba" in model_config_name:
|
||||||
|
return addict.Dict(
|
||||||
|
{
|
||||||
|
"model_type": "mamba",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
raise err
|
||||||
|
|
||||||
if cfg.model_config:
|
if cfg.model_config:
|
||||||
for key, val in cfg.model_config.items():
|
for key, val in cfg.model_config.items():
|
||||||
@@ -301,7 +308,9 @@ def load_model(
|
|||||||
or cfg.is_falcon_derived_model
|
or cfg.is_falcon_derived_model
|
||||||
or cfg.is_mistral_derived_model
|
or cfg.is_mistral_derived_model
|
||||||
):
|
):
|
||||||
model_kwargs["use_flash_attention_2"] = True
|
# TODO enable once properly supported in transformers
|
||||||
|
# model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
model_kwargs["use_flash_attention_2"] = True # legacy, to be deprecated
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
@@ -363,6 +372,15 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
elif model_type == "MixtralForCausalLM":
|
||||||
|
from axolotl.models.mixtral import MixtralForCausalLM
|
||||||
|
|
||||||
|
model = MixtralForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
elif model_type == "MambaLMHeadModel":
|
elif model_type == "MambaLMHeadModel":
|
||||||
# FIXME this is janky at best and hacked together to make it work
|
# FIXME this is janky at best and hacked together to make it work
|
||||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||||
|
|||||||
Reference in New Issue
Block a user