Mixtral fixes 20240124 (#1192) [skip ci]

* mixtral nccl fixes

* make sure to patch for z3
This commit is contained in:
Wing Lian
2024-01-24 14:59:57 -05:00
committed by GitHub
parent af0243021c
commit 54d2ac155b
14 changed files with 71 additions and 13 deletions

View File

@@ -861,7 +861,7 @@ tokens:
fsdp: fsdp:
fsdp_config: fsdp_config:
# Deepspeed config path. e.g., deepspeed/zero3.json # Deepspeed config path. e.g., deepspeed_configs/zero3.json
deepspeed: deepspeed:
# Advanced DDP Arguments # Advanced DDP Arguments
@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3. We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml ```yaml
deepspeed: deepspeed/zero1.json deepspeed: deepspeed_configs/zero1.json
``` ```
```shell ```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
``` ```
##### FSDP ##### FSDP

View File

@@ -62,7 +62,7 @@ evals_per_epoch: 4
eval_table_size: eval_table_size:
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1 weight_decay: 0.1
fsdp: fsdp:
fsdp_config: fsdp_config:

View File

@@ -942,7 +942,7 @@
"not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n", "not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
"For more information read axolotl's readme\n", "For more information read axolotl's readme\n",
"\"\"\"\n", "\"\"\"\n",
"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed/zero3_bf16.json" "!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
] ]
} }
], ],

View File

@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
#default deepspeed, can use more aggressive if needed like zero2, zero3 #default deepspeed, can use more aggressive if needed like zero2, zero3
deepspeed: deepspeed/zero1.json deepspeed: deepspeed_configs/zero1.json
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:

View File

@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml
If you run into CUDA OOM, use deepspeed with config zero2.json: If you run into CUDA OOM, use deepspeed with config zero2.json:
```shell ```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
``` ```

View File

@@ -84,7 +84,7 @@ eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed/zero2.json deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:

View File

@@ -3,7 +3,7 @@
Due to some nuances with the phi code, please use deepspeed when training phi for full finetune. Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
```shell ```shell
accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json
# OR # OR

View File

@@ -1,12 +1,61 @@
""" """
Patches to support multipack for mixtral Patches to support multipack for mixtral
""" """
import torch
import transformers import transformers
from axolotl.monkeypatch.utils import get_unpad_data from axolotl.monkeypatch.utils import get_unpad_data
def replace_mixtral_attn_with_multipack_flash_attn(): def patch_mixtral_moe_forward_zero3() -> None:
import torch.nn.functional as F
def mlp_forward(self, hidden_states):
    """
    Gated feed-forward pass used as the replacement expert MLP forward.

    Computes ``w2(act_fn(w1(x)) * w3(x))`` for the given ``hidden_states``
    and returns the projected result. No routing weights are applied here;
    the caller is responsible for any per-token scaling.
    """
    # the activated w1 projection is evaluated before the w3 projection
    activated = self.act_fn(self.w1(hidden_states))
    gated = activated * self.w3(hidden_states)
    return self.w2(gated)
# Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(
    self, hidden_states: torch.Tensor
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Sparse mixture-of-experts forward that calls every expert on every pass.

    Each token is routed to its ``self.top_k`` highest-scoring experts; the
    expert outputs are combined using the renormalised routing weights.

    Args:
        hidden_states: activations unpacked as
            ``(batch_size, sequence_length, hidden_dim)``.

    Returns:
        A ``(final_hidden_states, router_logits)`` pair where
        ``final_hidden_states`` has the same 3-D shape as the input and
        ``router_logits`` is ``(batch_size * sequence_length, n_experts)``.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # reshape (rather than view) so non-contiguous activations are accepted
    hidden_states = hidden_states.reshape(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    # routing probabilities are computed in float32
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(
        routing_weights, self.top_k, dim=-1, sorted=False
    )
    # renormalise so the selected experts' weights sum to one per token
    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    topk_weight = topk_weight.to(hidden_states.dtype)

    # one row per (token, selected expert) pair; rows of a token are adjacent,
    # which lines up with the flattened expert indices below
    hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
    expert_out = torch.empty_like(hidden_states)
    flat_topk_idx = topk_idx.view(-1)

    # Every expert is invoked, even when no row is routed to it (the selection
    # is then empty). NOTE(review): presumably this keeps all ranks executing
    # the same modules under DeepSpeed ZeRO-3 -- confirm before adding a skip.
    for i in range(self.num_experts):
        routed = flat_topk_idx == i
        expert_out[routed] = self.experts[i](hidden_states[routed])

    # weight each selected expert's output and sum over the top_k axis
    combined = (
        expert_out.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
    ).sum(dim=1)
    final_hidden_states = combined.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits
from transformers.models.mixtral.modeling_mixtral import (
MixtralBLockSparseTop2MLP,
MixtralSparseMoeBlock,
)
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -15,7 +15,7 @@ from optimum.bettertransformer import BetterTransformer
from peft import PeftModel from peft import PeftModel
from pkg_resources import get_distribution # type: ignore from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging

View File

@@ -21,7 +21,7 @@ from transformers import ( # noqa: F401
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -333,7 +333,10 @@ def load_model(
) )
LOG.info("patching mixtral with flash attention") LOG.info("patching mixtral with flash attention")
replace_mixtral_attn_with_multipack_flash_attn() mixtral_patch_kwargs = {}
if is_deepspeed_zero3_enabled():
mixtral_patch_kwargs["for_zero3"] = True
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing: if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import ( from axolotl.monkeypatch.falcon import (
@@ -646,6 +649,12 @@ def load_model(
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
from deepspeed.utils import set_z3_leaf_modules
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if cfg.model_config_type == "qwen" and cfg.adapter == "lora": if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled # Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True