diff --git a/README.md b/README.md
index c12aa3bba..b3f292c7d 100644
--- a/README.md
+++ b/README.md
@@ -562,7 +562,8 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
+liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
 ```
 
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
index 0320e0213..6b8771d81 100644
--- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -9,7 +9,7 @@ strict: false
 plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true
 
 chat_template: deepseek_v2
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
index 99ba63fcc..043b5c980 100644
--- a/examples/llama-3/fft-8b-liger-fsdp.yaml
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -4,7 +4,7 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true
 
 strict: false
diff --git a/requirements.txt b/requirements.txt
index 6bb1aa684..ec823a82a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.46.1
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
-accelerate==1.0.1
+accelerate==1.1.0
 datasets==3.0.1
 deepspeed==0.15.3
 pydantic==2.6.3
@@ -34,7 +34,7 @@
 tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
-liger-kernel==0.3.0
+liger-kernel==0.4.0
 mamba-ssm==1.2.0.post1
 
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index aab9a80b8..4fadd7eb4 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)
 
-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        return super()._save_checkpoint(model, trial, **kwargs)
 
 
 class AxolotlMambaTrainer(AxolotlTrainer):
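A quick aside on the `_save_checkpoint` override above: the hook being overridden is a private `transformers.Trainer` method whose parameter list has shifted across releases, so accepting and forwarding `**kwargs` (rather than naming `metrics=` explicitly) presumably keeps the subclass working on either side of the pinned version bump. A minimal sketch of the pattern, with throwaway `Parent`/`Child` classes invented for illustration:

```python
# Sketch only: why forwarding **kwargs keeps an override resilient to upstream
# signature changes. Parent stands in for transformers.Trainer here.
class Parent:
    def _save_checkpoint(self, model, trial, **kwargs):
        return f"saved (extra args: {kwargs})"


class Child(Parent):
    def _save_checkpoint(self, model, trial, **kwargs):
        # do local work first (the real override ensures the checkpoint dir exists),
        # then pass through whatever the caller supplied instead of hard-coding
        # parameters like `metrics=` that upstream may add or remove.
        return super()._save_checkpoint(model, trial, **kwargs)


print(Child()._save_checkpoint(model=None, trial=None, metrics={"loss": 0.1}))
```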
""" +import inspect import logging import sys -from functools import partial from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss -from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from axolotl.integrations.base import BasePlugin +from ...utils.distributed import zero_only from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 +LOG = logging.getLogger("axolotl.integrations.liger") + class LigerPlugin(BasePlugin): """ @@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin): return "axolotl.integrations.liger.LigerArgs" def pre_model_load(self, cfg): - if cfg.model_config_type == "llama": - from liger_kernel.transformers.model.llama import ( - lce_forward as llama_lce_forward, - ) - from transformers.models.llama import modeling_llama - - if cfg.liger_rope: - modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - modeling_llama.LlamaRMSNorm = LigerRMSNorm - if cfg.liger_swiglu: - modeling_llama.LlamaMLP = LigerSwiGLUMLP - if cfg.liger_cross_entropy: - modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss - elif cfg.liger_fused_linear_cross_entropy: - modeling_llama.LlamaForCausalLM.forward = llama_lce_forward - - elif cfg.model_config_type == "mistral": - from liger_kernel.transformers.model.mistral import ( - lce_forward as mistral_lce_forward, - ) - from transformers.models.mistral import modeling_mistral - - if cfg.liger_rope: - modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - modeling_mistral.MistralRMSNorm = LigerRMSNorm - if cfg.liger_swiglu: - modeling_mistral.MistralMLP = LigerSwiGLUMLP - if cfg.liger_cross_entropy: - modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss - if cfg.liger_fused_linear_cross_entropy: - modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward - - elif cfg.model_config_type == "gemma": - from liger_kernel.transformers.model.gemma import ( - lce_forward as gemma_lce_forward, - ) - from transformers.models.gemma import modeling_gemma - - if cfg.liger_rope: - modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - modeling_gemma.GemmaRMSNorm = partial( - LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: + apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] + liger_fn_sig = inspect.signature(apply_liger_fn) + kwargs = {} + if "rope" in liger_fn_sig.parameters: + kwargs["rope"] = cfg.liger_rope + if "cross_entropy" in liger_fn_sig.parameters: + kwargs["cross_entropy"] = cfg.liger_cross_entropy + if "fused_linear_cross_entropy" in liger_fn_sig.parameters: + kwargs[ + "fused_linear_cross_entropy" + ] = cfg.liger_fused_linear_cross_entropy + if "rms_norm" in liger_fn_sig.parameters: + kwargs["rms_norm"] = cfg.liger_rms_norm + if "layer_norm" in liger_fn_sig.parameters: + kwargs["layer_norm"] = cfg.liger_layer_norm + if "geglu" in liger_fn_sig.parameters: + kwargs["geglu"] = cfg.liger_glu_activation + elif "swiglu" in liger_fn_sig.parameters: + kwargs["swiglu"] = cfg.liger_glu_activation + with zero_only(): + LOG.info( + f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" ) - if cfg.liger_swiglu: - modeling_gemma.GemmaMLP = LigerGEGLUMLP - if 
diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py
index decdb3775..02ece3143 100644
--- a/src/axolotl/integrations/liger/args.py
+++ b/src/axolotl/integrations/liger/args.py
@@ -15,9 +15,12 @@
 """
 Module for handling LIGER input arguments.
 """
+import logging
 from typing import Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
+
+LOG = logging.getLogger("axolotl.integrations.liger.args")
 
 
 class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@
 
     liger_rope: Optional[bool] = None
     liger_rms_norm: Optional[bool] = None
+    liger_layer_norm: Optional[bool] = None
     liger_swiglu: Optional[bool] = None
+    liger_glu_activation: Optional[bool] = None
     liger_cross_entropy: Optional[bool] = None
     liger_fused_linear_cross_entropy: Optional[bool] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deprecated_swiglu(cls, data):
+        if data.get("liger_swiglu") is not None:
+            if data.get("liger_glu_activation") is not None:
+                raise ValueError(
+                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
+                )
+
+            LOG.warning(
+                "The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
+                "Please use 'liger_glu_activation' instead."
+            )
+            data["liger_glu_activation"] = data.pop("liger_swiglu")
+        return data
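For readers less familiar with pydantic v2, the shim above is a `mode="before"` validator: it sees the raw input mapping before field validation, so the deprecated key can be remapped (or rejected when both spellings are present) without declaring it twice. A standalone toy version of the same pattern, with invented field names:

```python
"""Toy sketch of a pydantic v2 "before" validator that remaps a deprecated field."""
import logging
from typing import Optional

from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)


class ToyArgs(BaseModel):
    old_flag: Optional[bool] = None  # deprecated spelling
    new_flag: Optional[bool] = None  # preferred spelling

    @model_validator(mode="before")
    @classmethod
    def remap_old_flag(cls, data):
        if data.get("old_flag") is not None:
            if data.get("new_flag") is not None:
                raise ValueError("set only one of `old_flag` / `new_flag`")
            LOG.warning("`old_flag` is deprecated; use `new_flag` instead")
            data["new_flag"] = data.pop("old_flag")
        return data


print(ToyArgs(old_flag=False))  # -> old_flag=None new_flag=False
```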
""" +import logging from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, model_validator + +LOG = logging.getLogger("axolotl.integrations.liger.args") class LigerArgs(BaseModel): @@ -27,6 +30,24 @@ class LigerArgs(BaseModel): liger_rope: Optional[bool] = None liger_rms_norm: Optional[bool] = None + liger_layer_norm: Optional[bool] = None liger_swiglu: Optional[bool] = None + liger_glu_activation: Optional[bool] = None liger_cross_entropy: Optional[bool] = None liger_fused_linear_cross_entropy: Optional[bool] = None + + @model_validator(mode="before") + @classmethod + def check_deprecated_swiglu(cls, data): + if data.get("liger_swiglu") is not None: + if data.get("liger_glu_activation") is not None: + raise ValueError( + "You cannot have both `liger_swiglu` and `liger_glu_activation` set." + ) + + LOG.warning( + "The 'liger_swiglu' argument is deprecated and will be removed in a future release. " + "Please use 'liger_glu_activation' instead." + ) + data["liger_glu_activation"] = data.pop("liger_swiglu") + return data diff --git a/tests/e2e/integrations/liger.py b/tests/e2e/integrations/liger.py index 4497cebe3..bb4574dff 100644 --- a/tests/e2e/integrations/liger.py +++ b/tests/e2e/integrations/liger.py @@ -1,7 +1,6 @@ """ Simple end-to-end test for Liger integration """ - import unittest from pathlib import Path diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integrations/liger.py b/tests/integrations/liger.py new file mode 100644 index 000000000..61540a57c --- /dev/null +++ b/tests/integrations/liger.py @@ -0,0 +1,80 @@ +""" +config validation tests for swiglu args +""" +# pylint: disable=duplicate-code +import logging +from typing import Optional + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="minimal_base_cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + +class BaseValidation: + """ + Base validation module to setup the log capture + """ + + _caplog: Optional[pytest.LogCaptureFixture] = None + + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + + +# pylint: disable=too-many-public-methods +class TestValidation(BaseValidation): + """ + Test the validation module for liger + """ + + def test_deprecated_swiglu(self, minimal_cfg): + test_cfg = DictDefault( + { + "liger_swiglu": False, + } + | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + updated_cfg = validate_config(test_cfg) + assert ( + "The 'liger_swiglu' argument is deprecated" + in self._caplog.records[0].message + ) + assert updated_cfg.liger_swiglu is None + assert updated_cfg.liger_glu_activations is False + + def test_conflict_swiglu_ligergluactivation(self, minimal_cfg): + test_cfg = DictDefault( + { + "liger_swiglu": False, + "liger_glu_activations": True, + } + | minimal_cfg + ) + + with pytest.raises( + ValueError, + match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*", + ): + validate_config(test_cfg) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8e2955414..a57b6d83e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -306,6 +306,10 @@ 
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 8e2955414..a57b6d83e 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
         """Verify that processing data from the hub works with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
+
+            # make sure prepared_path is empty
+            shutil.rmtree(prepared_path, ignore_errors=True)
+
             cfg = DictDefault(
                 {
                     "tokenizer_config": "huggyllama/llama-7b",
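The `tests/test_datasets.py` tweak only guards against stale artifacts from a previous run; `ignore_errors=True` also covers the case where the directory never existed. A minimal illustration of the idiom:

```python
import shutil
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
    prepared_path = Path(tmp_dir) / "prepared"
    # remove any leftovers from an earlier run; no error if the path is absent
    shutil.rmtree(prepared_path, ignore_errors=True)
```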