feat: integrate new modelling into cli
This commit is contained in:
@@ -16,11 +16,14 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.lolcats.linearize_attention import (
|
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
|
||||||
remove_base_attention,
|
LinearLlamaConfig,
|
||||||
toggle_attention,
|
)
|
||||||
|
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||||
|
LinearLlamaForCausalLM,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model_config
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -42,11 +45,15 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
# load model
|
# load model
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
# convert attention
|
# load config
|
||||||
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
base_config = load_model_config(cfg)
|
||||||
|
|
||||||
model = convert_attention(
|
# convert to linear llama
|
||||||
model, cfg.attention_config, train_attention=True, remove_base_attn=True
|
linear_llama_config = LinearLlamaConfig.from_llama(
|
||||||
|
base_config, cfg.attention_config
|
||||||
|
)
|
||||||
|
model = LinearLlamaForCausalLM.from_llama(
|
||||||
|
model, config=linear_llama_config, train_attention=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
@@ -56,7 +63,7 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
total_num_steps = dataset_meta.total_num_steps
|
total_num_steps = dataset_meta.total_num_steps
|
||||||
|
|
||||||
# toggle attention to be trainable
|
# toggle attention to be trainable
|
||||||
model = toggle_attention(model, train=True)
|
model.toggle_attention(train=True)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
@@ -73,8 +80,8 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
||||||
|
|
||||||
# drop base_attention + remove training attn
|
# drop base_attention + remove training attn
|
||||||
model = toggle_attention(model, train=False)
|
model.toggle_attention(train=False)
|
||||||
model = remove_base_attention(model)
|
model.remove_base_attention()
|
||||||
|
|
||||||
# NOTE: If in peft mode, consider whether to auto-merge
|
# NOTE: If in peft mode, consider whether to auto-merge
|
||||||
|
|
||||||
|
|||||||
@@ -62,3 +62,18 @@ class LinearLlamaConfig(LlamaConfig):
|
|||||||
|
|
||||||
# Set default attention config if none provided
|
# Set default attention config if none provided
|
||||||
self.attention_config = attention_config
|
self.attention_config = attention_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
|
||||||
|
"""
|
||||||
|
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llama_config (:class:`~transformers.LlamaConfig`):
|
||||||
|
The LlamaConfig to inherit from.
|
||||||
|
|
||||||
|
attention_config (`dict`):
|
||||||
|
Dictionary containing the configuration for linear attention mechanism.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return cls(attention_config=attention_config, **llama_config.to_dict())
|
||||||
|
|||||||
@@ -75,6 +75,10 @@ class LinearLlamaModel(LlamaModel):
|
|||||||
|
|
||||||
|
|
||||||
class LinearLlamaForCausalLM(LlamaForCausalLM):
|
class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||||
|
"""
|
||||||
|
Linear LLaMA model for causal language modeling.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = LinearLlamaModel(config)
|
self.model = LinearLlamaModel(config)
|
||||||
@@ -113,3 +117,21 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
def toggle_attention(self, train: bool = True):
|
||||||
|
"""
|
||||||
|
Toggle attention to be trainable or not
|
||||||
|
"""
|
||||||
|
from axolotl.integrations.lolcats.linearize_attention import toggle_attention
|
||||||
|
|
||||||
|
toggle_attention(self.model, train=train)
|
||||||
|
|
||||||
|
def remove_base_attention(self):
|
||||||
|
"""
|
||||||
|
Remove base attention after distillation
|
||||||
|
"""
|
||||||
|
from axolotl.integrations.lolcats.linearize_attention import (
|
||||||
|
remove_base_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
remove_base_attention(self.model)
|
||||||
|
|||||||
Reference in New Issue
Block a user