Compare commits

...

3 Commits

Author SHA1 Message Date
Wing Lian
f1b4030cdd WIP shampoo low bit optimizers 2024-11-08 10:02:10 -05:00
Wing Lian
035e9f9dd7 janky workaround to install FA2 on torch 2.5.1 base image since it takes forever to build (#2022) 2024-11-07 17:54:29 -05:00
Wing Lian
02ce520b7e upgrade liger to 0.4.0 (#1973)
* upgrade liger to 0.3.1

* update docs and example

* skip duplicate code check

* Update src/axolotl/integrations/liger/args.py

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update README.md

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* add logging

* chore: lint

* add test case

* upgrade liger and transformers

* also upgrade accelerate

* use kwargs to support patch release

* make sure prepared path is empty for test

* use transfromers 4.46.1 since 4.46.2 breaks fsdp

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-07 12:53:34 -05:00
14 changed files with 400 additions and 119 deletions

View File

@@ -562,7 +562,8 @@ plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true
liger_rms_norm: true liger_rms_norm: true
liger_swiglu: true liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
``` ```

View File

@@ -35,3 +35,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
pip3 install flash-attn==2.6.3; \
fi

View File

@@ -9,7 +9,7 @@ strict: false
plugins: plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true liger_rms_norm: true
liger_swiglu: true liger_glu_activation: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2 chat_template: deepseek_v2

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true
liger_rms_norm: true liger_rms_norm: true
liger_swiglu: true liger_glu_activation: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
strict: false strict: false

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.13.2 peft==0.13.2
transformers==4.46.0 transformers==4.46.1
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.44.1 bitsandbytes==0.44.1
accelerate==1.0.1 accelerate==1.1.0
datasets==3.0.1 datasets==3.0.1
deepspeed==0.15.3 deepspeed==0.15.3
pydantic==2.6.3 pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq>=0.2.5 autoawq>=0.2.5
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.3.0 liger-kernel==0.4.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1

View File

@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None): def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey # make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial) run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder) output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics) return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):

View File

@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training. Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight. It is designed to be performant, correct, and light-weight.
""" """
import inspect
import logging import logging
import sys import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin): class LigerPlugin(BasePlugin):
""" """
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs" return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
if cfg.model_config_type == "llama": if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
from liger_kernel.transformers.model.llama import ( apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
lce_forward as llama_lce_forward, liger_fn_sig = inspect.signature(apply_liger_fn)
) kwargs = {}
from transformers.models.llama import modeling_llama if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if cfg.liger_rope: if "cross_entropy" in liger_fn_sig.parameters:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb kwargs["cross_entropy"] = cfg.liger_cross_entropy
if cfg.liger_rms_norm: if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
modeling_llama.LlamaRMSNorm = LigerRMSNorm kwargs[
if cfg.liger_swiglu: "fused_linear_cross_entropy"
modeling_llama.LlamaMLP = LigerSwiGLUMLP ] = cfg.liger_fused_linear_cross_entropy
if cfg.liger_cross_entropy: if "rms_norm" in liger_fn_sig.parameters:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss kwargs["rms_norm"] = cfg.liger_rms_norm
elif cfg.liger_fused_linear_cross_entropy: if "layer_norm" in liger_fn_sig.parameters:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
elif cfg.model_config_type == "mistral": kwargs["geglu"] = cfg.liger_glu_activation
from liger_kernel.transformers.model.mistral import ( elif "swiglu" in liger_fn_sig.parameters:
lce_forward as mistral_lce_forward, kwargs["swiglu"] = cfg.liger_glu_activation
) with zero_only():
from transformers.models.mistral import modeling_mistral LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
) )
if cfg.liger_swiglu: apply_liger_fn(**kwargs)
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba": elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba from transformers.models.jamba import modeling_jamba
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu: if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2": elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.") logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu: if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy: if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,9 +15,12 @@
""" """
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
import logging
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
class LigerArgs(BaseModel): class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

View File

View File

@@ -0,0 +1,250 @@
from typing import Optional
import torch
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
class _ShampooBase(Optimizer):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size,
quantization_bits,
optimizer_state_class,
):
if lr <= 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if eps < 0.0:
raise ValueError(f"Invalid eps value: {eps}")
if update_freq < 1:
raise ValueError(f"Invalid update_freq value: {update_freq}")
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
eps=eps,
update_freq=update_freq,
)
super().__init__(params, defaults)
self.block_size = block_size
self.quantization_bits = quantization_bits
self.optimizer_state_class = optimizer_state_class
def step(self, closure: Optional[callable] = None) -> Optional[float]:
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
state["momentum_buffer"] = self._new_buffer(grad, True)
state["preconds"] = []
state["inv_preconds"] = []
for dim in grad.size():
state["preconds"].append(
self.optimizer_state_class.zeros(
(dim, dim),
signed=False,
block_size=self.block_size,
device=grad.device,
)
)
state["inv_preconds"].append(
torch.zeros((dim, dim), device=grad.device)
)
state["step"] += 1
beta = group["momentum"]
weight_decay = group["weight_decay"]
lr = group["lr"]
eps = group["eps"]
update_freq = group["update_freq"]
# Apply momentum
if beta > 0:
state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
grad = state["momentum_buffer"]
# Apply weight decay
if weight_decay > 0:
grad = grad.add(p.data, alpha=weight_decay)
# Preconditioning
order = grad.ndimension()
original_size = grad.size()
for dim_id, dim in enumerate(grad.size()):
precond = state["preconds"][dim_id]
inv_precond = state["inv_preconds"][dim_id]
# Reshape grad
grad = grad.transpose(0, dim_id).contiguous()
transposed_size = grad.size()
grad = grad.view(dim, -1)
grad_t = grad.t()
# Update preconditioner
precond_fp32 = precond.dequantize()
precond_update = grad @ grad_t
precond_fp32.add_(precond_update)
# Quantize preconditioner back
precond.copy_(precond_fp32)
# Update inverse preconditioner
if state["step"] % update_freq == 0:
inv_precond.copy_(
self._compute_inv_precond(precond_fp32, eps, order)
)
# Precondition grad
if dim_id == order - 1:
# Last dimension
grad = grad_t @ inv_precond
grad = grad.view(original_size)
else:
grad = inv_precond @ grad
grad = grad.view(transposed_size)
# Update parameter
p.data.add_(grad, alpha=-lr)
return loss
def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
# Add eps for numerical stability
precond = precond + torch.eye(precond.size(0), device=precond.device) * eps
# Compute matrix power
inv_precond = self._matrix_power(precond, -1.0 / (2 * order))
return inv_precond
def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
# Compute matrix power using SVD
u, s, v = torch.svd(matrix)
s_pow = s.pow(power)
return u @ torch.diag(s_pow) @ v.t()
# bring your own function to create zero-filled subclass
@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
raise NotImplementedError
# follow bitsandbytes, only quantize tensors >= 4096 values
# also wrap subclass in DTensor when needed
def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = DTensor.from_local(
local_tensor=self._subclass_zeros(
p.to_local(), signed, self.block_size
),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
)
else:
out = self._subclass_zeros(p, signed, self.block_size)
else:
out = torch.zeros_like(p)
return out
class Shampoo8bit(_ShampooBase):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size=256,
):
super().__init__(
params,
lr,
momentum,
weight_decay,
eps,
update_freq,
block_size=block_size,
quantization_bits=8,
optimizer_state_class=OptimState8bit,
)
class Shampoo4bit(_ShampooBase):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size=128,
):
super().__init__(
params,
lr,
momentum,
weight_decay,
eps,
update_freq,
block_size=block_size,
quantization_bits=4,
optimizer_state_class=OptimState4bit,
)
class ShampooFp8(_ShampooBase):
def __init__(
self,
params,
lr=1e-1,
momentum=0.0,
weight_decay=0.0,
eps=1e-4,
update_freq=1,
*,
block_size=256,
):
super().__init__(
params,
lr,
momentum,
weight_decay,
eps,
update_freq,
block_size=block_size,
quantization_bits=8, # FP8 uses 8 bits
optimizer_state_class=OptimStateFp8,
)

View File

@@ -1,7 +1,6 @@
""" """
Simple end-to-end test for Liger integration Simple end-to-end test for Liger integration
""" """
import unittest import unittest
from pathlib import Path from pathlib import Path

View File

View File

@@ -0,0 +1,80 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
}
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
validate_config(test_cfg)

View File

@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
"""Verify that processing data from the hub works with a specific revision""" """Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
# make sure prepared_path is empty
shutil.rmtree(prepared_path, ignore_errors=True)
cfg = DictDefault( cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b", "tokenizer_config": "huggyllama/llama-7b",