feat: integrate new modelling into cli

This commit is contained in:
NanoCode012
2025-02-04 19:46:05 +07:00
parent 1fb8d86396
commit 2bc7833a4e
3 changed files with 54 additions and 10 deletions

View File

@@ -16,11 +16,14 @@ from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linearize_attention import (
remove_base_attention,
toggle_attention,
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
LinearLlamaConfig,
)
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
LinearLlamaForCausalLM,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.trainer import setup_trainer
LOG = logging.getLogger(__name__)
@@ -42,11 +45,15 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
# load model
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
# convert attention
from axolotl.integrations.lolcats.linearize_attention import convert_attention
# load config
base_config = load_model_config(cfg)
model = convert_attention(
model, cfg.attention_config, train_attention=True, remove_base_attn=True
# convert to linear llama
linear_llama_config = LinearLlamaConfig.from_llama(
base_config, cfg.attention_config
)
model = LinearLlamaForCausalLM.from_llama(
model, config=linear_llama_config, train_attention=True
)
# Get datasets
@@ -56,7 +63,7 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
total_num_steps = dataset_meta.total_num_steps
# toggle attention to be trainable
model = toggle_attention(model, train=True)
model.toggle_attention(train=True)
# Setup trainer
trainer = setup_trainer(
@@ -73,8 +80,8 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
# drop base_attention + remove training attn
model = toggle_attention(model, train=False)
model = remove_base_attention(model)
model.toggle_attention(train=False)
model.remove_base_attention()
# NOTE: If in peft mode, consider whether to auto-merge

View File

@@ -62,3 +62,18 @@ class LinearLlamaConfig(LlamaConfig):
# Set default attention config if none provided
self.attention_config = attention_config
@classmethod
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
"""
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
Args:
llama_config (:class:`~transformers.LlamaConfig`):
The LlamaConfig to inherit from.
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
"""
return cls(attention_config=attention_config, **llama_config.to_dict())

View File

@@ -75,6 +75,10 @@ class LinearLlamaModel(LlamaModel):
class LinearLlamaForCausalLM(LlamaForCausalLM):
"""
Linear LLaMA model for causal language modeling.
"""
def __init__(self, config):
super().__init__(config)
self.model = LinearLlamaModel(config)
@@ -113,3 +117,21 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
)
return new_model
def toggle_attention(self, train: bool = True):
"""
Toggle attention to be trainable or not
"""
from axolotl.integrations.lolcats.linearize_attention import toggle_attention
toggle_attention(self.model, train=train)
def remove_base_attention(self):
"""
Remove base attention after distillation
"""
from axolotl.integrations.lolcats.linearize_attention import (
remove_base_attention,
)
remove_base_attention(self.model)