diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
new file mode 100644
index 000000000..b55646df7
--- /dev/null
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -0,0 +1,67 @@
+base_model: deepseek-ai/DeepSeek-V2-Lite
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
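+# FSDP full-shard settings: the auto-wrap policy targets the remote-code
+# DeepseekV2DecoderLayer class, and CPU parameter offload trades some throughput
+# for lower per-GPU memory.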
""" import logging +import sys from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP -from liger_kernel.transformers.model.llama import lce_forward +from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP @@ -53,7 +54,7 @@ class LigerPlugin(BasePlugin): if cfg.liger_cross_entropy: modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss elif cfg.liger_fused_linear_cross_entropy: - modeling_llama.LlamaForCausalLM.forward = lce_forward + modeling_llama.LlamaForCausalLM.forward = llama_lce_forward elif cfg.model_config_type == "mistral": from transformers.models.mistral import modeling_mistral @@ -102,3 +103,45 @@ class LigerPlugin(BasePlugin): modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward + + elif cfg.model_config_type == "qwen2": + from liger_kernel.transformers.model.qwen2 import ( + lce_forward as qwen2_lce_forward, + ) + from transformers.models.qwen2 import modeling_qwen2 + + if cfg.liger_rope: + modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward + + elif cfg.model_config_type == "deepseek_v2": + from accelerate import init_empty_weights + from transformers import AutoModelForCausalLM + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = sys.modules[model.__class__.__module__] + + from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward + + if cfg.liger_rope: + # The DeepseekV2 version of RoPE is different than upstream LLaMA. 
+            with init_empty_weights():
+                model = AutoModelForCausalLM.from_pretrained(
+                    cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
+                )
+                modeling_mod = sys.modules[model.__class__.__module__]
+
+            from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
+
+            if cfg.liger_rope:
+                # The DeepseekV2 version of RoPE is different from upstream LLaMA's.
+                # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
+                logging.warning("Fused liger_rope is not supported for DeepseekV2.")
+            if cfg.liger_rms_norm:
+                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
+            if cfg.liger_swiglu:
+                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
+            if cfg.liger_cross_entropy:
+                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
+            if cfg.liger_fused_linear_cross_entropy:
+                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py
new file mode 100644
index 000000000..79fb27436
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/deepseekv2.py
@@ -0,0 +1,127 @@
+"""
+DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
+"""
+# pylint: disable=duplicate-code
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
+# @replace_return_docstrings(
+#     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+# )
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
+
+    >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training:
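+        # Hand the lm_head weight and the shifted hidden states straight to Liger's
+        # fused linear + cross-entropy kernel, so the full (tokens, vocab_size) logits
+        # tensor is never materialized; with DeepseekV2's large vocabulary this avoids
+        # a sizable activation allocation during training.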
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
diff --git a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py
new file mode 100644
index 000000000..dfc3e29c5
--- /dev/null
+++ b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py
@@ -0,0 +1,51 @@
+"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk"""
+
+import importlib
+import os
+import sys
+import typing
+from pathlib import Path
+
+from transformers.file_utils import HF_MODULES_CACHE
+
+
+def _patched_get_class_in_module(
+    class_name: str, module_path: typing.Union[str, os.PathLike]
+) -> typing.Type:
+    """
+    Import a module from the modules cache directory and extract a class from it.
+
+    Args:
+        class_name (`str`): The name of the class to import.
+        module_path (`str` or `os.PathLike`): The path to the module to import.
+
+    Returns:
+        `typing.Type`: The class looked for.
+    """
+    name = os.path.normpath(module_path)
+    if name.endswith(".py"):
+        name = name[:-3]
+    name = name.replace(os.path.sep, ".")
+    module_spec = importlib.util.spec_from_file_location(
+        name, location=Path(HF_MODULES_CACHE) / module_path
+    )
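+    # Unlike the upstream implementation, reuse a module that has already been
+    # imported instead of re-executing it from disk; re-execution would discard any
+    # monkey-patches applied to the previously loaded module.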
+ """ + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_spec = importlib.util.spec_from_file_location( + name, location=Path(HF_MODULES_CACHE) / module_path + ) + module = sys.modules.get(name) + if module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + # load in initial case only + module_spec.loader.exec_module(module) + return getattr(module, class_name) + + +def patch_transformers_dynamic_module_utils(): + """ + Recently, transformers started reloading modeling code from disk for models marked trust_remote_code=True. + This causes monkey-patches for multipack and liger to be removed. + We replace the original function with a version that does not reload the module from disk. + See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581 + """ + import transformers + + transformers.dynamic_module_utils.get_class_in_module = _patched_get_class_in_module diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index e43c58650..f29f21be7 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: max_num = int(torch.max(attention_mask).item()) batch_size, _ = attention_mask.shape counts = torch.zeros((batch_size, max_num), dtype=torch.int32) - for i in range(1, max_num + 1): mask = attention_mask == i counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32) - result = counts.flatten() nonzero_indices = torch.nonzero(result).squeeze(-1) return result[nonzero_indices] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 01abf5483..c746ccaf7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -43,6 +43,9 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) +from axolotl.monkeypatch.transformers_dynamic_module_utils import ( + patch_transformers_dynamic_module_utils, +) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import get_chat_template_from_config @@ -54,6 +57,8 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod LOG = logging.getLogger("axolotl") +patch_transformers_dynamic_module_utils() + # copied from accelerator.FullyShardedDataParallelPlugin def get_module_class_from_name(module, name):