diff --git a/src/axolotl/integrations/lolcats/__init__.py b/src/axolotl/integrations/lolcats/__init__.py index a815721fa..d9d916c0f 100644 --- a/src/axolotl/integrations/lolcats/__init__.py +++ b/src/axolotl/integrations/lolcats/__init__.py @@ -11,6 +11,8 @@ from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import ( DistillAttentionXentMSETrainer, ) +from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401 + LOG = logging.getLogger("axolotl.integrations.lolcats") diff --git a/src/axolotl/integrations/lolcats/args.py b/src/axolotl/integrations/lolcats/args.py index 0c2bd6561..97d6613e9 100644 --- a/src/axolotl/integrations/lolcats/args.py +++ b/src/axolotl/integrations/lolcats/args.py @@ -2,12 +2,46 @@ Module for handling linear attention input arguments. """ +from typing import Optional + from pydantic import BaseModel +class FeatureMapKwargs(BaseModel): + """Args for feature map""" + + eps: float + mlp: Optional[None] = None + fullspace: bool + + +class LearnedKernelKwargs(BaseModel): + """Args for learned kernel""" + + feature_dim: int + skip_connection: bool + bias: bool + zero_init: bool + + +class AttentionConfig(BaseModel): + """Args for attention config""" + + attention_type: str + feature_map: str + feature_map_kwargs: FeatureMapKwargs + layer_idx: Optional[None] = None + learned_kernel: str + learned_kernel_kwargs: LearnedKernelKwargs + tie_qk_kernels: bool + train_qk: bool + + class LinearAttentionArgs(BaseModel): """ Input args for linear attention """ - attention_config: dict + attention_config: AttentionConfig + + linearize: bool