Compare commits
1 Commits
mixtral_sw
...
20231212-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bb4a782ce |
@@ -23,7 +23,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
|||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import add_defaults, normalize_config, validate_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
@@ -301,6 +301,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
add_defaults(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,4 @@ Custom modeling code for mixtral
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .configuration_moe_mistral import MixtralConfig # noqa
|
from .configuration_moe_mistral import MixtralConfig # noqa
|
||||||
from .modeling_moe_mistral import ( # noqa
|
from .modeling_moe_mistral import MixtralForCausalLM # noqa
|
||||||
MixtralForCausalLM,
|
|
||||||
replace_mixtral_mlp_with_swiglu,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -46,9 +46,8 @@ from transformers.utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from xformers.ops import SwiGLU
|
|
||||||
|
|
||||||
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
from .configuration_moe_mistral import MixtralConfig
|
from .configuration_moe_mistral import MixtralConfig
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
@@ -69,61 +68,6 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "MixtralConfig"
|
_CONFIG_FOR_DOC = "MixtralConfig"
|
||||||
|
|
||||||
|
|
||||||
def replace_mixtral_mlp_with_swiglu(model):
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, FeedForward):
|
|
||||||
mlp = FusedMLP(
|
|
||||||
module.config,
|
|
||||||
module.gate_proj,
|
|
||||||
module.up_proj,
|
|
||||||
module.down_proj,
|
|
||||||
)
|
|
||||||
set_module_name(model, name, mlp)
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMLP(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
gate_proj: torch.nn.Linear,
|
|
||||||
up_proj: torch.nn.Linear,
|
|
||||||
down_proj: torch.nn.Linear,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.swiglu = SwiGLU(
|
|
||||||
in_features=config.hidden_size,
|
|
||||||
hidden_features=config.intermediate_size,
|
|
||||||
bias=False,
|
|
||||||
_pack_weights=True,
|
|
||||||
)
|
|
||||||
# overwrite initialized weights with pretrained weights
|
|
||||||
self.swiglu.w12.weight.data = torch.cat(
|
|
||||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
|
||||||
)
|
|
||||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
|
||||||
|
|
||||||
def _post_training(self, model, name):
|
|
||||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
|
||||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign the split weights back to the original layers
|
|
||||||
new_mlp = FeedForward(self.config)
|
|
||||||
new_mlp.w1.weight.data = w1
|
|
||||||
new_mlp.w2.weight.data = w2
|
|
||||||
new_mlp.w3.weight.data = self.swiglu.w3.weight.data
|
|
||||||
|
|
||||||
set_module_name(model, name, new_mlp)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
|
||||||
return self.swiglu(x)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||||
def _get_unpad_data(attention_mask):
|
def _get_unpad_data(attention_mask):
|
||||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
|||||||
@@ -41,6 +41,16 @@ def choose_device(cfg):
|
|||||||
cfg.device_map = None
|
cfg.device_map = None
|
||||||
|
|
||||||
|
|
||||||
|
def add_defaults(cfg):
|
||||||
|
# setup sane defaults if left unspecified
|
||||||
|
if cfg.dataloader_num_workers is None:
|
||||||
|
cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
if cfg.dataloader_prefetch_factor is None:
|
||||||
|
cfg.dataloader_prefetch_factor = cfg.batch_size * 2
|
||||||
|
if cfg.dataloader_pin_memory is None:
|
||||||
|
cfg.dataloader_pin_memory = True
|
||||||
|
|
||||||
|
|
||||||
def normalize_config(cfg):
|
def normalize_config(cfg):
|
||||||
# setup some derived config / hyperparams
|
# setup some derived config / hyperparams
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
|||||||
@@ -373,10 +373,7 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type == "MixtralForCausalLM":
|
elif model_type == "MixtralForCausalLM":
|
||||||
from axolotl.models.mixtral import (
|
from axolotl.models.mixtral import MixtralForCausalLM
|
||||||
MixtralForCausalLM,
|
|
||||||
replace_mixtral_mlp_with_swiglu,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = MixtralForCausalLM.from_pretrained(
|
model = MixtralForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -384,11 +381,6 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.flash_attn_fuse_mlp:
|
|
||||||
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
|
|
||||||
replace_mixtral_mlp_with_swiglu(model)
|
|
||||||
|
|
||||||
elif model_type == "MambaLMHeadModel":
|
elif model_type == "MambaLMHeadModel":
|
||||||
# FIXME this is janky at best and hacked together to make it work
|
# FIXME this is janky at best and hacked together to make it work
|
||||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||||
|
|||||||
Reference in New Issue
Block a user