Compare commits

...

10 Commits

Author SHA1 Message Date
Wing Lian
6b617a4fd5 also upgrade accelerate 2024-11-06 08:59:52 -05:00
Wing Lian
6ac10de9ef upgrade liger and transformers 2024-11-06 08:53:03 -05:00
Wing Lian
1b8d439441 add test case 2024-11-05 09:23:08 +07:00
Wing Lian
1ed351781a chore: lint 2024-11-05 09:23:08 +07:00
Wing Lian
c2a48c3a1e add logging 2024-11-05 09:23:08 +07:00
Wing Lian
415399b565 Update README.md
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-05 09:23:08 +07:00
Wing Lian
67c04133f2 Update src/axolotl/integrations/liger/args.py
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-05 09:23:08 +07:00
Wing Lian
4911d0952f skip duplicate code check 2024-11-05 09:23:08 +07:00
Wing Lian
1d7ab52161 update docs and example 2024-11-05 09:23:08 +07:00
Wing Lian
fcdc6fee8b upgrade liger to 0.3.1 2024-11-05 09:23:08 +07:00
10 changed files with 142 additions and 119 deletions

View File

@@ -562,7 +562,8 @@ plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true
liger_rms_norm: true liger_rms_norm: true
liger_swiglu: true liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
``` ```

View File

@@ -9,7 +9,7 @@ strict: false
plugins: plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true liger_rms_norm: true
liger_swiglu: true liger_glu_activation: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2 chat_template: deepseek_v2

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true
liger_rms_norm: true liger_rms_norm: true
liger_swiglu: true liger_glu_activation: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
strict: false strict: false

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.13.2 peft==0.13.2
transformers==4.46.0 transformers==4.46.2
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.44.1 bitsandbytes==0.44.1
accelerate==1.0.1 accelerate==1.1.0
datasets==3.0.1 datasets==3.0.1
deepspeed==0.15.3 deepspeed==0.15.3
pydantic==2.6.3 pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq>=0.2.5 autoawq>=0.2.5
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.3.0 liger-kernel==0.4.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1

View File

@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None): def _save_checkpoint(self, model, trial):
# make sure the checkpoint dir exists, since trainer is flakey # make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial) run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder) output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics) return super()._save_checkpoint(model, trial)
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):

View File

@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training. Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight. It is designed to be performant, correct, and light-weight.
""" """
import inspect
import logging import logging
import sys import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin): class LigerPlugin(BasePlugin):
""" """
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs" return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
if cfg.model_config_type == "llama": if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
from liger_kernel.transformers.model.llama import ( apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
lce_forward as llama_lce_forward, liger_fn_sig = inspect.signature(apply_liger_fn)
) kwargs = {}
from transformers.models.llama import modeling_llama if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if cfg.liger_rope: if "cross_entropy" in liger_fn_sig.parameters:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb kwargs["cross_entropy"] = cfg.liger_cross_entropy
if cfg.liger_rms_norm: if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
modeling_llama.LlamaRMSNorm = LigerRMSNorm kwargs[
if cfg.liger_swiglu: "fused_linear_cross_entropy"
modeling_llama.LlamaMLP = LigerSwiGLUMLP ] = cfg.liger_fused_linear_cross_entropy
if cfg.liger_cross_entropy: if "rms_norm" in liger_fn_sig.parameters:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss kwargs["rms_norm"] = cfg.liger_rms_norm
elif cfg.liger_fused_linear_cross_entropy: if "layer_norm" in liger_fn_sig.parameters:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
elif cfg.model_config_type == "mistral": kwargs["geglu"] = cfg.liger_glu_activation
from liger_kernel.transformers.model.mistral import ( elif "swiglu" in liger_fn_sig.parameters:
lce_forward as mistral_lce_forward, kwargs["swiglu"] = cfg.liger_glu_activation
) with zero_only():
from transformers.models.mistral import modeling_mistral LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
) )
if cfg.liger_swiglu: apply_liger_fn(**kwargs)
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba": elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba from transformers.models.jamba import modeling_jamba
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu: if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2": elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.") logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu: if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,9 +15,12 @@
""" """
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
import logging
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
class LigerArgs(BaseModel): class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

View File

@@ -1,7 +1,6 @@
""" """
Simple end-to-end test for Liger integration Simple end-to-end test for Liger integration
""" """
import unittest import unittest
from pathlib import Path from pathlib import Path

View File

View File

@@ -0,0 +1,80 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
}
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
validate_config(test_cfg)