From 40a6362c9256a462a0f44539233b7b54fd54de11 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 9 Dec 2023 12:10:41 -0500
Subject: [PATCH] support for mamba (#915)

* support for mamba

* more mamba fixes

* use fork for mamba kwargs fix

* grad checkpointing doesn't work

* fix extras for mamba

* mamba loss fix

* use fp32 and remove verbose logging

* mamba fixes

* fix collator for mamba

* set model_type on training_args

* don't save safetensors for mamba

* update mamba config to disable safetensor checkpoints, install for tests

* no evals for mamba tests

* handle save_pretrained

* handle unused safetensors arg
---
 .github/workflows/tests.yml                 |   2 +-
 examples/mamba/config.yml                   |  61 +++++++++
 setup.py                                    |   3 +
 src/axolotl/core/trainer_builder.py         |  55 +++++++-
 src/axolotl/models/mamba/__init__.py        |  12 ++
 .../models/mamba/configuration_mamba.py     |  42 ++++++
 src/axolotl/models/mamba/modeling_mamba.py  | 128 ++++++++++++++++++
 src/axolotl/train.py                        |   6 +-
 src/axolotl/utils/collators.py              |  34 ++++-
 src/axolotl/utils/models.py                 |  51 +++++--
 src/axolotl/utils/trainer.py                |  12 +-
 tests/e2e/test_mamba.py                     |  65 +++++++++
 12 files changed, 447 insertions(+), 24 deletions(-)
 create mode 100644 examples/mamba/config.yml
 create mode 100644 src/axolotl/models/mamba/__init__.py
 create mode 100644 src/axolotl/models/mamba/configuration_mamba.py
 create mode 100644 src/axolotl/models/mamba/modeling_mamba.py
 create mode 100644 tests/e2e/test_mamba.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9103126ce..ad2cb428b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -73,7 +73,7 @@ jobs:
         run: |
           pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
           pip3 uninstall -y transformers accelerate
-          pip3 install -U -e .[flash-attn]
+          pip3 install -U -e .[flash-attn,mamba-ssm]
           pip3 install -r requirements-tests.txt

       - name: Run e2e tests
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
new file mode 100644
index 000000000..c2e5a851f
--- /dev/null
+++ b/examples/mamba/config.yml
@@ -0,0 +1,61 @@
+base_model: state-spaces/mamba-2.8b
+model_type: MambaLMHeadModel
+tokenizer_type: AutoTokenizer
+tokenizer_config: EleutherAI/gpt-neox-20b
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 2048
+sample_packing: false
+pad_to_sequence_len: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 5e-5
+
+train_on_inputs: false
+group_by_length: true
+
+bf16: true
+fp16: false
+tf32: true
+
+gradient_checkpointing: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention:
+
+warmup_steps: 10
+eval_steps:
+eval_table_size:
+eval_table_max_new_tokens: 128
+save_steps: 0.25
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+tokens:
+save_safetensors: False
diff --git a/setup.py b/setup.py
index 986160273..42fd22df1 100644
--- a/setup.py
+++ b/setup.py
@@ -51,5 +51,8 @@ setup(
         "deepspeed": [
             "deepspeed",
         ],
+        "mamba-ssm": [
+            "mamba-ssm==1.0.1",
+        ],
     },
 )
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d166691f1..1b037420c 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import (
+    BatchSamplerDataCollatorForSeq2Seq,
+    MambaDataCollator,
+)
 from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     Extend the base TrainingArguments for axolotl helpers
     """

+    model_type: Optional[str] = field(
+        default=None, metadata={"help": "HF model configuration model_type."}
+    )
     lr_quadratic_warmup: bool = field(
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
@@ -285,6 +291,32 @@ class AxolotlTrainer(Trainer):

         return super().compute_loss(model, inputs, return_outputs=return_outputs)


+class AxolotlMambaTrainer(AxolotlTrainer):
+    """
+    Mamba specific trainer to handle loss calculation
+    """
+
+    def compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,  # pylint: disable=unused-argument
+    ):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+        )
+
+        return lm_loss
+
+
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -462,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
+        if self.cfg.model_config_type == "mamba":
+            return AxolotlMambaTrainer
         return AxolotlTrainer

     def build(self, total_num_steps):
@@ -529,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.hub_strategy:
             training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy

-        if self.cfg.save_safetensors:
+        if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

         if self.cfg.sample_packing_eff_est:
@@ -677,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
+        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
             )
@@ -731,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
+            data_collator=self.build_collator(**data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -755,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.micro_batch_size

         return trainer
+
+    def build_collator(self, **kwargs):
+        if self.cfg.model_config_type == "mamba":
+            return MambaDataCollator(tokenizer=self.tokenizer)
+
+        return BatchSamplerDataCollatorForSeq2Seq(
+            self.tokenizer,
+            return_tensors="pt",
+            **kwargs,
+        )
diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py
new file mode 100644
index 000000000..247c1d184
--- /dev/null
+++ b/src/axolotl/models/mamba/__init__.py
@@ -0,0 +1,12 @@
+"""
+Modeling module for Mamba models
+"""
+
+
+def fix_mamba_attn_for_loss():
+    from mamba_ssm.models import mixer_seq_simple
+
+    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
+
+    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
+    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
diff --git a/src/axolotl/models/mamba/configuration_mamba.py b/src/axolotl/models/mamba/configuration_mamba.py
new file mode 100644
index 000000000..5160ee8d7
--- /dev/null
+++ b/src/axolotl/models/mamba/configuration_mamba.py
@@ -0,0 +1,42 @@
+"""
+HF Transformers MambaConfig
+"""
+from transformers import PretrainedConfig
+
+
+class MambaConfig(PretrainedConfig):
+    """
+    modeling configuration for state space model/mamba
+    """
+
+    model_type = "mamba"
+
+    def __init__(
+        self,
+        vocab_size=50280,
+        d_model=2560,
+        n_layer=64,
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=8,
+        pad_token_id=50277,
+        bos_token_id=0,
+        eos_token_id=0,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_layer = n_layer
+        self.rms_norm = rms_norm
+        self.residual_in_fp32 = residual_in_fp32
+        self.fused_add_norm = fused_add_norm
+        self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
diff --git a/src/axolotl/models/mamba/modeling_mamba.py b/src/axolotl/models/mamba/modeling_mamba.py
new file mode 100644
index 000000000..70e9c88c8
--- /dev/null
+++ b/src/axolotl/models/mamba/modeling_mamba.py
@@ -0,0 +1,128 @@
+# pylint: skip-file
+import os
+from collections import namedtuple
+from functools import partial
+from typing import Optional, Union
+
+import torch
+from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
+from mamba_ssm.utils.generation import GenerationMixin
+from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from axolotl.models.mamba.configuration_mamba import MambaConfig
+
+
+class MambaLMHeadModel(nn.Module, GenerationMixin):
+    def __init__(
+        self,
+        d_model: int,
+        n_layer: int,
+        vocab_size: int,
+        initializer_cfg=None,
+        pad_vocab_size_multiple: int = 1,
+        device=None,
+        dtype=None,
+        **backbone_kwargs,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if vocab_size % pad_vocab_size_multiple != 0:
+            vocab_size += pad_vocab_size_multiple - (
+                vocab_size % pad_vocab_size_multiple
+            )
+        self.config = MambaConfig(
+            vocab_size=vocab_size,
+            d_model=d_model,
+            n_layer=n_layer,
+            pad_vocab_size_multiple=pad_vocab_size_multiple,
+        )
+        self.backbone = MixerModel(
+            d_model=d_model,
+            n_layer=n_layer,
+            vocab_size=vocab_size,
+            initializer_cfg=initializer_cfg,
+            **backbone_kwargs,
+            **factory_kwargs,
+        )
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+        # Initialize weights and apply final processing
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+        self.tie_weights()
+
+    def tie_weights(self):
+        self.lm_head.weight = self.backbone.embedding.weight
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.backbone.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        inference_params=None,
+        num_last_tokens=0,
+        labels=None,
+        **kwargs,
+    ):
+        """
+        "position_ids" is just to be compatible with Transformer generation. We don't use it.
+        num_last_tokens: if > 0, only return the logits for the last n tokens
+        """
+        hidden_states = self.backbone(input_ids, inference_params=inference_params)
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
+        lm_logits = self.lm_head(hidden_states)
+
+        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+        return CausalLMOutput(logits=lm_logits)
+
+        loss = None
+        if labels is not None:
+            logits = lm_logits
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
+            print(loss)
+            return CausalLMOutput(logits=lm_logits, loss=loss)
+
+        else:
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+            return CausalLMOutput(logits=lm_logits)
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        state_dict: Optional[dict] = None,
+        safe_serialization: Optional[bool] = None,  # pylint: disable=unused-argument
+    ):
+        if state_dict is None:
+            state_dict = self.state_dict()
+        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+        config = load_config_hf(pretrained_model_name)
+        model = cls(**config, device=device, dtype=dtype, **kwargs)
+        model.load_state_dict(
+            load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
+        )
+        return model
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 60c76b1b0..022d230cb 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -82,7 +82,8 @@ def train(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )

-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False

     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
@@ -92,7 +93,8 @@ def train(
         if not Path(cfg.output_dir).is_dir():
             os.makedirs(cfg.output_dir, exist_ok=True)
         tokenizer.save_pretrained(str(Path(cfg.output_dir)))
-        model.config.save_pretrained(str(Path(cfg.output_dir)))
+        if hasattr(model, "config"):
+            model.config.save_pretrained(str(Path(cfg.output_dir)))

     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index ffae3f263..0f0eb5a95 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -2,12 +2,16 @@
 DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Sequence, Union

 import numpy as np
+import torch
+import transformers
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy

+IGNORE_INDEX = -100
+

 @dataclass
 class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
+
+
+@dataclass
+class MambaDataCollator:
+    """
+    Collator for State Space Models (Mamba)
+    """
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple(
+            [torch.LongTensor(instance[key]) for instance in instances]
+            for key in ("input_ids", "labels")
+        )
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids,
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id,
+        )
+        labels = torch.nn.utils.rnn.pad_sequence(
+            labels, batch_first=True, padding_value=IGNORE_INDEX
+        )
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+        }
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6c77ea4c6..a48ffc7a3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -4,6 +4,7 @@ import math
 import os
 from typing import Optional, Tuple  # noqa: F401

+import addict
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -21,6 +22,7 @@ from transformers import (  # noqa: F401
     PreTrainedTokenizerBase,
 )

+from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
@@ -52,9 +54,19 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    model_config = AutoConfig.from_pretrained(
-        model_config_name, trust_remote_code=trust_remote_code
-    )
+    try:
+        model_config = AutoConfig.from_pretrained(
+            model_config_name, trust_remote_code=trust_remote_code
+        )
+    except ValueError as err:
+        if "mamba" in model_config_name:
+            return addict.Dict(
+                {
+                    "model_type": "mamba",
+                }
+            )
+        raise err
+
     if cfg.model_config:
         for key, val in cfg.model_config.items():
             setattr(model_config, key, val)
@@ -351,6 +363,20 @@ def load_model(
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
+    elif model_type == "MambaLMHeadModel":
+        # FIXME this is janky at best and hacked together to make it work
+        MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
+
+        model_kwargs["dtype"] = model_kwargs["torch_dtype"]
+        model_kwargs["device"] = torch.cuda.current_device()
+        del model_kwargs["torch_dtype"]
+        del model_kwargs["device_map"]
+        del model_kwargs["max_memory"]
+
+        model = MambaLMHeadModel.from_pretrained(
+            base_model,
+            **model_kwargs,
+        )
     elif model_type and not cfg.trust_remote_code:
         if cfg.gptq:
             model = AutoModelForCausalLM.from_pretrained(
@@ -410,13 +436,17 @@ def load_model(
         if cfg.resize_token_embeddings_to_32x
         else len(tokenizer)
     )
-    if model.get_input_embeddings().num_embeddings < embeddings_len:
+    if (
+        hasattr(model, "get_input_embeddings")
+        and model.get_input_embeddings().num_embeddings < embeddings_len
+    ):
         model.resize_token_embeddings(embeddings_len)
     else:
         model.tie_weights()

     if (
-        hasattr(model.config, "max_position_embeddings")
+        hasattr(model, "config")
+        and hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
         and cfg.sequence_len > model.config.max_position_embeddings
     ):
@@ -426,20 +456,22 @@
         model.config.max_position_embeddings = cfg.sequence_len

     if (
-        hasattr(model.config, "bos_token_id")
+        hasattr(model, "config")
+        and hasattr(model.config, "bos_token_id")
         and model.config.bos_token_id
         and model.config.bos_token_id != tokenizer.bos_token_id
     ):
         model.config.bos_token_id = tokenizer.bos_token_id

     if (
-        hasattr(model.config, "eos_token_id")
+        hasattr(model, "config")
+        and hasattr(model.config, "eos_token_id")
         and model.config.eos_token_id
         and model.config.eos_token_id != tokenizer.eos_token_id
     ):
         model.config.eos_token_id = tokenizer.eos_token_id

-    if model.device.type == "cuda":
+    if hasattr(model, "device") and model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)

     # make sure these are fp32 per Ramesh et al. (2021)
@@ -498,7 +530,8 @@ def load_model(
             requires_grad.append(f"{name}: {param.requires_grad}")
     if len(requires_grad) == 0:
         LOG.warning("there are no parameters that require gradient updates")
-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False

     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 469f6d886..590861cc0 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         )

     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
-        cfg.is_mistral_derived_model and cfg.flash_attention
+    if (
+        "CodeGenTokenizer" in tokenizer.__class__.__name__
+        or (cfg.is_mistral_derived_model and cfg.flash_attention)
+        or cfg.model_config_type == "mamba"
     ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     if update:
         cfg.total_num_tokens = total_num_tokens

-    if not cfg.total_supervised_tokens:
+    skip_estimates = cfg.model_config_type == "mamba"
+
+    if not skip_estimates and not cfg.total_supervised_tokens:
         total_supervised_tokens = (
             train_dataset.data.column("labels")
             .to_pandas()
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         if update:
             cfg.total_supervised_tokens = total_supervised_tokens

-    if cfg.sample_packing:
+    if not skip_estimates and cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
new file mode 100644
index 000000000..463b0ddac
--- /dev/null
+++ b/tests/e2e/test_mamba.py
@@ -0,0 +1,65 @@
+"""
+E2E tests for mamba
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMamba(unittest.TestCase):
+    """
+    Test case for Mamba models
+    """
+
+    @with_temp_dir
+    def test_fft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "state-spaces/mamba-130m",
+                "model_type": "MambaLMHeadModel",
+                "tokenizer_type": "AutoTokenizer",
+                "tokenizer_config": "EleutherAI/gpt-neox-20b",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_8bit": False,
+                "val_set_size": 0.0,
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "gradient_checkpointing": False,
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": None,
+                "save_safetensors": False,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
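
Note (not part of the patch): in the Mamba path the model's forward() returns only logits during training, so AxolotlMambaTrainer.compute_loss shifts the labels and computes cross-entropy outside the model. A minimal, self-contained sketch of that calculation follows; the tensor shapes are illustrative assumptions, not values taken from the patch.

    import torch

    # Illustrative shapes only: batch of 2, sequence length 8, GPT-NeoX-sized vocab.
    batch, seq_len, vocab = 2, 8, 50280
    lm_logits = torch.randn(batch, seq_len, vocab)   # stand-in for model(input_ids).logits
    input_ids = torch.randint(0, vocab, (batch, seq_len))

    # Same shift-by-one cross-entropy as AxolotlMambaTrainer.compute_loss:
    # logits at position t are scored against the token at position t + 1.
    labels = input_ids.to(lm_logits.device)
    shift_logits = lm_logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss()
    lm_loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
    print(lm_loss)

With the new extra installed (pip3 install -e .[mamba-ssm]), a run with the example config would presumably be launched through axolotl's usual entrypoint, e.g. accelerate launch -m axolotl.cli.train examples/mamba/config.yml.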