Compare commits

7 commits: seq-parall...upgrade_li
| Author | SHA1 | Date |
|---|---|---|
| | d1bf20f990 | |
| | bb648cbc63 | |
| | 8b0bca4842 | |
| | d36baf44b1 | |
| | 16c8140d20 | |
| | 21c25cf7bc | |
| | 32288a5d3c | |
````diff
@@ -121,7 +121,7 @@ Features:
 
 Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
 
-**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
+**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
 
 ```bash
 git clone https://github.com/axolotl-ai-cloud/axolotl
````

```diff
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - typescript
     type: ... # unimplemented custom format
 
-  # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
+  # fastchat conversation (deprecation soon, use chat_template)
   # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
   - path: ...
     type: sharegpt
```

````diff
@@ -562,7 +562,8 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
+liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
 ```
 
````
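The `plugins:` key takes fully qualified class paths that Axolotl resolves and instantiates at startup. A minimal sketch of how such dotted paths can be resolved at runtime (a generic illustration using `importlib`, not Axolotl's actual plugin manager):

```python
import importlib


def load_plugins(dotted_paths):
    """Resolve 'pkg.module.ClassName' strings into plugin instances."""
    instances = []
    for path in dotted_paths:
        module_name, _, class_name = path.rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        instances.append(cls())
    return instances


# Works for any importable class; with axolotl installed, the same call
# would resolve ["axolotl.integrations.liger.LigerPlugin"].
print(load_plugins(["collections.OrderedDict"]))  # [OrderedDict()]
```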

```diff
@@ -4,32 +4,26 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true
 
 strict: false
 
-chat_template: llama3
 datasets:
-  - path: mlabonne/FineTome-100k
-    type: chat_template
-    split: train[:20%]
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+  - path: tatsu-lab/alpaca
+    type: alpaca
 
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.02
+val_set_size: 0
 output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
 pad_to_sequence_len: true
 
-wandb_project:
-wandb_entity:
+wandb_project: check_liger_hf_GA_llama_fix-3
+wandb_entity: axolotl-ai
 wandb_watch:
-wandb_name:
+wandb_name: pr/fix333-tr4.46.1
 wandb_log_model:
 
 gradient_accumulation_steps: 4
```
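For reference, the `type: alpaca` loader in the updated example consumes instruction-tuning rows with the usual `instruction`/`input`/`output` fields; the values below are illustrative, not from the dataset:

```python
# One example row in the schema that `type: alpaca` consumes.
row = {
    "instruction": "Give three tips for staying healthy.",
    "input": "",  # optional context; empty for instruction-only rows
    "output": "1. Eat a balanced diet ...",
}
```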

```diff
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.46.1
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
 accelerate==1.0.1
```

```diff
@@ -33,8 +33,8 @@ gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
-triton>=2.3.0
-liger-kernel==0.3.0
+triton>=3.1.0
+liger-kernel==0.3.1
 
 mamba-ssm==1.2.0.post1
 
```

```diff
@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
 Liger Kernel is the collection of Triton-native kernels for LLM Training.
 It is designed to be performant, correct, and light-weight.
 """
+import inspect
 import logging
 import sys
-from functools import partial
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
-from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
 from axolotl.integrations.base import BasePlugin
 
+from ...utils.distributed import zero_only
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
 
+LOG = logging.getLogger("axolotl.integrations.liger")
+
 
 class LigerPlugin(BasePlugin):
     """
```

```diff
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
         return "axolotl.integrations.liger.LigerArgs"
 
     def pre_model_load(self, cfg):
-        if cfg.model_config_type == "llama":
-            from liger_kernel.transformers.model.llama import (
-                lce_forward as llama_lce_forward,
-            )
-            from transformers.models.llama import modeling_llama
-
-            if cfg.liger_rope:
-                modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_llama.LlamaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_llama.LlamaMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
-            elif cfg.liger_fused_linear_cross_entropy:
-                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
-
-        elif cfg.model_config_type == "mistral":
-            from liger_kernel.transformers.model.mistral import (
-                lce_forward as mistral_lce_forward,
-            )
-            from transformers.models.mistral import modeling_mistral
-
-            if cfg.liger_rope:
-                modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_mistral.MistralRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_mistral.MistralMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
-
-        elif cfg.model_config_type == "gemma":
-            from liger_kernel.transformers.model.gemma import (
-                lce_forward as gemma_lce_forward,
-            )
-            from transformers.models.gemma import modeling_gemma
-
-            if cfg.liger_rope:
-                modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma.GemmaRMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma.GemmaMLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
+            liger_fn_sig = inspect.signature(apply_liger_fn)
+            kwargs = {}
+            if "rope" in liger_fn_sig.parameters:
+                kwargs["rope"] = cfg.liger_rope
+            if "cross_entropy" in liger_fn_sig.parameters:
+                kwargs["cross_entropy"] = cfg.liger_cross_entropy
+            if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
+                kwargs[
+                    "fused_linear_cross_entropy"
+                ] = cfg.liger_fused_linear_cross_entropy
+            if "rms_norm" in liger_fn_sig.parameters:
+                kwargs["rms_norm"] = cfg.liger_rms_norm
+            if "layer_norm" in liger_fn_sig.parameters:
+                kwargs["layer_norm"] = cfg.liger_layer_norm
+            if "geglu" in liger_fn_sig.parameters:
+                kwargs["geglu"] = cfg.liger_glu_activation
+            elif "swiglu" in liger_fn_sig.parameters:
+                kwargs["swiglu"] = cfg.liger_glu_activation
+            with zero_only():
+                LOG.info(
+                    f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
+                )
+            apply_liger_fn(**kwargs)
 
         elif cfg.model_config_type == "jamba":
             from transformers.models.jamba import modeling_jamba
```
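The rewrite above drops the hand-written per-model branches in favor of liger-kernel's `MODEL_TYPE_TO_APPLY_LIGER_FN` registry, using `inspect.signature` to forward only the flags each patch function actually declares. Below is a self-contained sketch of that dispatch pattern; the registry and patch functions are toy stand-ins, not liger-kernel's real API:

```python
import inspect


def patch_llama(rope=False, rms_norm=False, swiglu=False,
                cross_entropy=False, fused_linear_cross_entropy=False):
    """Toy stand-in for a per-model patch function (swiglu-style GLU)."""
    return {name: value for name, value in locals().items() if value}


def patch_gemma(rope=False, rms_norm=False, geglu=False,
                cross_entropy=False, fused_linear_cross_entropy=False):
    """Toy stand-in for a per-model patch function (geglu-style GLU)."""
    return {name: value for name, value in locals().items() if value}


REGISTRY = {"llama": patch_llama, "gemma": patch_gemma}


def apply_kernels(model_type, cfg):
    # Mirror of the new pre_model_load logic: inspect the patch function's
    # signature and pass only the flags it accepts, so one config schema
    # drives patch functions with differing parameter lists.
    apply_fn = REGISTRY[model_type]
    params = inspect.signature(apply_fn).parameters
    kwargs = {}
    if "rope" in params:
        kwargs["rope"] = cfg["liger_rope"]
    if "rms_norm" in params:
        kwargs["rms_norm"] = cfg["liger_rms_norm"]
    # One config flag drives either GLU variant, whichever the model exposes.
    if "geglu" in params:
        kwargs["geglu"] = cfg["liger_glu_activation"]
    elif "swiglu" in params:
        kwargs["swiglu"] = cfg["liger_glu_activation"]
    return apply_fn(**kwargs)


cfg = {"liger_rope": True, "liger_rms_norm": True, "liger_glu_activation": True}
print(apply_kernels("llama", cfg))  # {'rope': True, 'rms_norm': True, 'swiglu': True}
print(apply_kernels("gemma", cfg))  # {'rope': True, 'rms_norm': True, 'geglu': True}
```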

```diff
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
                 modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
             if cfg.liger_rms_norm:
                 modeling_jamba.JambaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_jamba.JambaMLP = LigerSwiGLUMLP
             if cfg.liger_cross_entropy:
                 modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
 
-        elif cfg.model_config_type == "qwen2":
-            from liger_kernel.transformers.model.qwen2 import (
-                lce_forward as qwen2_lce_forward,
-            )
-            from transformers.models.qwen2 import modeling_qwen2
-
-            if cfg.liger_rope:
-                modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
-
         elif cfg.model_config_type == "deepseek_v2":
             from accelerate import init_empty_weights
             from transformers import AutoModelForCausalLM
```

```diff
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
                 logging.warning("Fused liger_rope is not supported for DeepseekV2.")
             if cfg.liger_rms_norm:
                 modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
             if cfg.liger_cross_entropy:
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
-
-        elif cfg.model_config_type == "gemma2":
-            from transformers.models.gemma2 import modeling_gemma2
-
-            if cfg.liger_rope:
-                modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma2.Gemma2RMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                logging.warning(
-                    "Fused linear cross entropy is not supported for Gemma 2."
-                )
-
-        elif cfg.model_config_type == "phi3":
-            from liger_kernel.transformers.model.phi3 import (
-                lce_forward as phi3_lce_forward,
-            )
-            from transformers.models.phi3 import modeling_phi3
-
-            if cfg.liger_rope:
-                modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_phi3.Phi3RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_phi3.Phi3MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
```

```diff
@@ -15,9 +15,12 @@
 """
 Module for handling LIGER input arguments.
 """
+import logging
 from typing import Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
+LOG = logging.getLogger("axolotl.integrations.liger.args")
 
 
 class LigerArgs(BaseModel):
```

```diff
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
 
     liger_rope: Optional[bool] = None
     liger_rms_norm: Optional[bool] = None
+    liger_layer_norm: Optional[bool] = None
     liger_swiglu: Optional[bool] = None
+    liger_glu_activation: Optional[bool] = None
     liger_cross_entropy: Optional[bool] = None
     liger_fused_linear_cross_entropy: Optional[bool] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deprecated_swiglu(cls, data):
+        if data.get("liger_swiglu") is not None:
+            if data.get("liger_glu_activation") is not None:
+                raise ValueError(
+                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
+                )
+
+            LOG.warning(
+                "The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
+                "Please use 'liger_glu_activation' instead."
+            )
+            data["liger_glu_activation"] = data.pop("liger_swiglu")
+        return data
```
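The added `model_validator` keeps the deprecated `liger_swiglu` key working while steering configs toward `liger_glu_activation`. A runnable, trimmed-down sketch of the same shim (the class name here is illustrative, not Axolotl's):

```python
import logging
from typing import Optional

from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)


class LigerArgsDemo(BaseModel):
    """Trimmed-down stand-in for LigerArgs to show the deprecation shim."""

    liger_swiglu: Optional[bool] = None
    liger_glu_activation: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_deprecated_swiglu(cls, data):
        # Runs on the raw input dict, before field validation.
        if data.get("liger_swiglu") is not None:
            if data.get("liger_glu_activation") is not None:
                raise ValueError(
                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
                )
            LOG.warning("'liger_swiglu' is deprecated; use 'liger_glu_activation'.")
            data["liger_glu_activation"] = data.pop("liger_swiglu")
        return data


print(LigerArgsDemo(liger_swiglu=True))
# -> liger_swiglu=None liger_glu_activation=True
```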