Compare commits
10 Commits
fsdp2_fp32
...
6b617a4fd5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b617a4fd5 | ||
|
|
6ac10de9ef | ||
|
|
1b8d439441 | ||
|
|
1ed351781a | ||
|
|
c2a48c3a1e | ||
|
|
415399b565 | ||
|
|
67c04133f2 | ||
|
|
4911d0952f | ||
|
|
1d7ab52161 | ||
|
|
fcdc6fee8b |
@@ -562,7 +562,8 @@ plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_swiglu: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ strict: false
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rms_norm: true
|
||||
liger_swiglu: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
chat_template: deepseek_v2
|
||||
|
||||
@@ -4,7 +4,7 @@ plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_swiglu: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
strict: false
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.13.2
|
||||
transformers==4.46.0
|
||||
transformers==4.46.2
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.44.1
|
||||
accelerate==1.0.1
|
||||
accelerate==1.1.0
|
||||
datasets==3.0.1
|
||||
deepspeed==0.15.3
|
||||
pydantic==2.6.3
|
||||
@@ -34,7 +34,7 @@ tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq>=0.2.5
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.3.0
|
||||
liger-kernel==0.4.0
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
|
||||
@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
def _save_checkpoint(self, model, trial):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||
return super()._save_checkpoint(model, trial)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
|
||||
@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||
It is designed to be performant, correct, and light-weight.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||
|
||||
|
||||
class LigerPlugin(BasePlugin):
|
||||
"""
|
||||
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
if cfg.model_config_type == "llama":
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
elif cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "mistral":
|
||||
from liger_kernel.transformers.model.mistral import (
|
||||
lce_forward as mistral_lce_forward,
|
||||
)
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma":
|
||||
from liger_kernel.transformers.model.gemma import (
|
||||
lce_forward as gemma_lce_forward,
|
||||
)
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma.GemmaRMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
kwargs = {}
|
||||
if "rope" in liger_fn_sig.parameters:
|
||||
kwargs["rope"] = cfg.liger_rope
|
||||
if "cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs[
|
||||
"fused_linear_cross_entropy"
|
||||
] = cfg.liger_fused_linear_cross_entropy
|
||||
if "rms_norm" in liger_fn_sig.parameters:
|
||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||
if "layer_norm" in liger_fn_sig.parameters:
|
||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
||||
if "geglu" in liger_fn_sig.parameters:
|
||||
kwargs["geglu"] = cfg.liger_glu_activation
|
||||
elif "swiglu" in liger_fn_sig.parameters:
|
||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
||||
|
||||
apply_liger_fn(**kwargs)
|
||||
elif cfg.model_config_type == "jamba":
|
||||
from transformers.models.jamba import modeling_jamba
|
||||
|
||||
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "qwen2":
|
||||
from liger_kernel.transformers.model.qwen2 import (
|
||||
lce_forward as qwen2_lce_forward,
|
||||
)
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "deepseek_v2":
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoModelForCausalLM
|
||||
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
|
||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma2":
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma2.Gemma2RMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
logging.warning(
|
||||
"Fused linear cross entropy is not supported for Gemma 2."
|
||||
)
|
||||
|
||||
elif cfg.model_config_type == "phi3":
|
||||
from liger_kernel.transformers.model.phi3 import (
|
||||
lce_forward as phi3_lce_forward,
|
||||
)
|
||||
from transformers.models.phi3 import modeling_phi3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
||||
|
||||
@@ -15,9 +15,12 @@
|
||||
"""
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger.args")
|
||||
|
||||
|
||||
class LigerArgs(BaseModel):
|
||||
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
if data.get("liger_swiglu") is not None:
|
||||
if data.get("liger_glu_activation") is not None:
|
||||
raise ValueError(
|
||||
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
|
||||
)
|
||||
|
||||
LOG.warning(
|
||||
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
|
||||
"Please use 'liger_glu_activation' instead."
|
||||
)
|
||||
data["liger_glu_activation"] = data.pop("liger_swiglu")
|
||||
return data
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
0
tests/integrations/__init__.py
Normal file
0
tests/integrations/__init__.py
Normal file
80
tests/integrations/liger.py
Normal file
80
tests/integrations/liger.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
config validation tests for swiglu args
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_base_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseValidation:
|
||||
"""
|
||||
Base validation module to setup the log capture
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation(BaseValidation):
|
||||
"""
|
||||
Test the validation module for liger
|
||||
"""
|
||||
|
||||
def test_deprecated_swiglu(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_swiglu": False,
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
updated_cfg = validate_config(test_cfg)
|
||||
assert (
|
||||
"The 'liger_swiglu' argument is deprecated"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
assert updated_cfg.liger_swiglu is None
|
||||
assert updated_cfg.liger_glu_activations is False
|
||||
|
||||
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_swiglu": False,
|
||||
"liger_glu_activations": True,
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
|
||||
):
|
||||
validate_config(test_cfg)
|
||||
Reference in New Issue
Block a user