Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
67
examples/deepseek-v2/fft-fsdp-16b.yaml
Normal file
67
examples/deepseek-v2/fft-fsdp-16b.yaml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# Full fine-tune of DeepSeek-V2-Lite, sharded with FSDP.
base_model: deepseek-ai/DeepSeek-V2-Lite
trust_remote_code: true

# No quantized loading.
load_in_8bit: false
load_in_4bit: false
strict: false

# Training data.
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

# Sequence length and packing.
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

# Weights & Biases settings (left unset).
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# Optimization.
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

# Precision.
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:

# FSDP sharding; each DeepseekV2DecoderLayer is wrapped as its own unit.
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
|
||||||
@@ -19,10 +19,11 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
|||||||
It is designed to be performant, correct, and light-weight.
|
It is designed to be performant, correct, and light-weight.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
from liger_kernel.transformers.model.llama import lce_forward
|
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
@@ -53,7 +54,7 @@ class LigerPlugin(BasePlugin):
|
|||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
elif cfg.liger_fused_linear_cross_entropy:
|
elif cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_llama.LlamaForCausalLM.forward = lce_forward
|
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
||||||
|
|
||||||
elif cfg.model_config_type == "mistral":
|
elif cfg.model_config_type == "mistral":
|
||||||
from transformers.models.mistral import modeling_mistral
|
from transformers.models.mistral import modeling_mistral
|
||||||
@@ -102,3 +103,45 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||||
|
|
||||||
|
elif cfg.model_config_type == "qwen2":
|
||||||
|
from liger_kernel.transformers.model.qwen2 import (
|
||||||
|
lce_forward as qwen2_lce_forward,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen2 import modeling_qwen2
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_swiglu:
|
||||||
|
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
||||||
|
|
||||||
|
elif cfg.model_config_type == "deepseek_v2":
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
|
||||||
|
)
|
||||||
|
modeling_mod = sys.modules[model.__class__.__module__]
|
||||||
|
|
||||||
|
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||||
|
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||||
|
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_swiglu:
|
||||||
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
|||||||
127
src/axolotl/integrations/liger/models/deepseekv2.py
Normal file
127
src/axolotl/integrations/liger/models/deepseekv2.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
||||||
|
LigerFusedLinearCrossEntropyLoss,
|
||||||
|
)
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
|
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
|
||||||
|
# @replace_return_docstrings(
|
||||||
|
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
# )
|
||||||
|
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Causal-LM forward pass that fuses the LM head with the cross-entropy loss while training.

    When the module is in training mode *and* `labels` are supplied, the loss is computed by
    `LigerFusedLinearCrossEntropyLoss` directly from the final hidden states and
    `self.lm_head.weight`; no logits tensor is produced on that path and `logits` is returned
    as `None`. In every other case the logits are computed with `self.lm_head`, upcast to
    float, and (if `labels` are supplied) a regular shifted `CrossEntropyLoss` is applied.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets for next-token prediction. They are shifted by one position inside this
            function, so they should be aligned with the inputs. When omitted, no loss is
            computed.

        All remaining arguments are forwarded unchanged to `self.model`; unset
        `output_attentions`, `output_hidden_states` and `return_dict` fall back to the
        corresponding `self.config` values.

    Returns:
        A `CausalLMOutputWithPast` when `return_dict` resolves to true, otherwise a tuple
        `(loss, logits, *outputs[1:])` with `loss` dropped when it is `None`.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    # The fused kernel needs targets; without labels fall through to the plain
    # logits path instead of failing on `labels[..., 1:]`.
    if self.training and labels is not None:
        # Shift so that tokens < n predict n
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
||||||
51
src/axolotl/monkeypatch/transformers_dynamic_module_utils.py
Normal file
51
src/axolotl/monkeypatch/transformers_dynamic_module_utils.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers.file_utils import HF_MODULES_CACHE
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_get_class_in_module(
    class_name: str, module_path: typing.Union[str, os.PathLike]
) -> typing.Type:
    """
    Import a module on the cache directory for modules and extract a class from it.

    Unlike a reloading implementation, the module is executed only the first time it
    is requested; later calls return the class from the module already registered in
    `sys.modules`, so attributes patched onto that module survive.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import,
            relative to `HF_MODULES_CACHE`.

    Returns:
        `typing.Type`: The class looked for.

    Raises:
        ImportError: If no import spec/loader can be built for `module_path`.
        AttributeError: If the module does not define `class_name`.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util  # pylint: disable=import-outside-toplevel

    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")

    module = sys.modules.get(name)
    if module is None:
        module_spec = importlib.util.spec_from_file_location(
            name, location=Path(HF_MODULES_CACHE) / module_path
        )
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Cannot build an import spec for {module_path!r}")
        module = importlib.util.module_from_spec(module_spec)
        # insert it into sys.modules before any loading begins
        sys.modules[name] = module
        # load in initial case only
        try:
            module_spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module cached, otherwise every
            # later call would return it and fail on the attribute lookup.
            sys.modules.pop(name, None)
            raise
    return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_transformers_dynamic_module_utils():
    """
    Install the non-reloading `get_class_in_module` into transformers.

    transformers began re-reading modeling code from disk for models loaded with
    trust_remote_code=True, which discards the monkey-patches applied for multipack
    and liger. Swapping in `_patched_get_class_in_module` keeps the already-imported
    (and patched) module in use.
    See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581
    """
    import transformers

    setattr(
        transformers.dynamic_module_utils,
        "get_class_in_module",
        _patched_get_class_in_module,
    )
|
||||||
@@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
|||||||
max_num = int(torch.max(attention_mask).item())
|
max_num = int(torch.max(attention_mask).item())
|
||||||
batch_size, _ = attention_mask.shape
|
batch_size, _ = attention_mask.shape
|
||||||
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
||||||
|
|
||||||
for i in range(1, max_num + 1):
|
for i in range(1, max_num + 1):
|
||||||
mask = attention_mask == i
|
mask = attention_mask == i
|
||||||
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
||||||
|
|
||||||
result = counts.flatten()
|
result = counts.flatten()
|
||||||
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
||||||
return result[nonzero_indices]
|
return result[nonzero_indices]
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ from axolotl.monkeypatch.multipack import (
|
|||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
)
|
)
|
||||||
|
from axolotl.monkeypatch.transformers_dynamic_module_utils import (
|
||||||
|
patch_transformers_dynamic_module_utils,
|
||||||
|
)
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
@@ -54,6 +57,8 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
patch_transformers_dynamic_module_utils()
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
def get_module_class_from_name(module, name):
|
def get_module_class_from_name(module, name):
|
||||||
|
|||||||
Reference in New Issue
Block a user