feat: integrate new modelling into cli
This commit is contained in:
@@ -16,11 +16,14 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.lolcats.linearize_attention import (
|
||||
remove_base_attention,
|
||||
toggle_attention,
|
||||
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
|
||||
LinearLlamaConfig,
|
||||
)
|
||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||
LinearLlamaForCausalLM,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -42,11 +45,15 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
# load model
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# convert attention
|
||||
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
||||
# load config
|
||||
base_config = load_model_config(cfg)
|
||||
|
||||
model = convert_attention(
|
||||
model, cfg.attention_config, train_attention=True, remove_base_attn=True
|
||||
# convert to linear llama
|
||||
linear_llama_config = LinearLlamaConfig.from_llama(
|
||||
base_config, cfg.attention_config
|
||||
)
|
||||
model = LinearLlamaForCausalLM.from_llama(
|
||||
model, config=linear_llama_config, train_attention=True
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
@@ -56,7 +63,7 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# toggle attention to be trainable
|
||||
model = toggle_attention(model, train=True)
|
||||
model.toggle_attention(train=True)
|
||||
|
||||
# Setup trainer
|
||||
trainer = setup_trainer(
|
||||
@@ -73,8 +80,8 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
||||
|
||||
# drop base_attention + remove training attn
|
||||
model = toggle_attention(model, train=False)
|
||||
model = remove_base_attention(model)
|
||||
model.toggle_attention(train=False)
|
||||
model.remove_base_attention()
|
||||
|
||||
# NOTE: If in peft mode, consider whether to auto-merge
|
||||
|
||||
|
||||
@@ -62,3 +62,18 @@ class LinearLlamaConfig(LlamaConfig):
|
||||
|
||||
# Set default attention config if none provided
|
||||
self.attention_config = attention_config
|
||||
|
||||
@classmethod
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
    """
    Build a linear-attention config on top of an existing Llama config.

    Every field exported by ``llama_config.to_dict()`` is forwarded to the
    constructor as a keyword argument, together with ``attention_config``.

    Args:
        llama_config (:class:`~transformers.LlamaConfig`):
            Base configuration whose fields are copied over.

        attention_config (`dict`):
            Settings for the linear attention mechanism; stored on the
            resulting config as ``attention_config``.

    Returns:
        A new instance of ``cls``.
    """
    # NOTE(review): the exported dict is passed through verbatim; confirm
    # that every key it carries (e.g. model_type) is meant to reach the
    # constructor of this subclass.
    inherited_fields = llama_config.to_dict()

    return cls(attention_config=attention_config, **inherited_fields)
|
||||
|
||||
@@ -75,6 +75,10 @@ class LinearLlamaModel(LlamaModel):
|
||||
|
||||
|
||||
class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
"""
|
||||
Linear LLaMA model for causal language modeling.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LinearLlamaModel(config)
|
||||
@@ -113,3 +117,21 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
)
|
||||
|
||||
return new_model
|
||||
|
||||
def toggle_attention(self, train: bool = True):
    """
    Switch the attention layers of the wrapped model between trainable
    and non-trainable.

    Args:
        train: Flag forwarded to the ``toggle_attention`` helper from
            ``linearize_attention``; defaults to ``True``.
    """
    # Imported at call time, under an alias so it is not confused with
    # this method of the same name.
    from axolotl.integrations.lolcats.linearize_attention import (
        toggle_attention as _toggle_attention,
    )

    # The helper's return value is discarded; presumably it updates
    # self.model in place -- TODO confirm.
    _toggle_attention(self.model, train=train)
|
||||
|
||||
def remove_base_attention(self):
    """
    Drop the base attention from the wrapped model once distillation is
    finished.

    Delegates to the ``remove_base_attention`` helper from
    ``linearize_attention``, applied to ``self.model``.
    """
    # Imported at call time, under an alias so it is not confused with
    # this method of the same name.
    from axolotl.integrations.lolcats.linearize_attention import (
        remove_base_attention as _remove_base_attention,
    )

    # The helper's return value is discarded; presumably it updates
    # self.model in place -- TODO confirm.
    _remove_base_attention(self.model)
|
||||
|
||||
Reference in New Issue
Block a user