use autoconfig w rala
This commit is contained in:
@@ -82,6 +82,7 @@ def convert_rala(cfg, cli_args, config_path):
|
|||||||
zero_init=cli_args.zero_init,
|
zero_init=cli_args.zero_init,
|
||||||
)
|
)
|
||||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
model.config.model_type = "llama-rala"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
# from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention
|
from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -18,7 +16,6 @@ class RalaPlugin(BasePlugin):
|
|||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.rala.args.RalaArgs"
|
return "axolotl.integrations.rala.args.RalaArgs"
|
||||||
|
|
||||||
def set_attn_config(self, cfg, model_kwargs, model_config):
|
def pre_model_load(self, cfg):
|
||||||
# if cfg.rala_attention:
|
if cfg.rala_attention:
|
||||||
# model_kwargs["attn_implementation"] = "rala"
|
register_rala_model()
|
||||||
...
|
|
||||||
|
|||||||
@@ -9,4 +9,5 @@ class LlamaRalaConfig(LlamaConfig):
|
|||||||
Configuration for LlamaRala model
|
Configuration for LlamaRala model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_type = "llama-rala"
|
||||||
softmax_every: int = 6 # every N-th layer applies softmax
|
softmax_every: int = 6 # every N-th layer applies softmax
|
||||||
|
|||||||
@@ -19,11 +19,18 @@ from typing import List, Optional, Tuple, Union, Unpack
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import Cache, GenerationMixin, LlamaModel
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
Cache,
|
||||||
|
GenerationMixin,
|
||||||
|
LlamaModel,
|
||||||
|
)
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LLAMA_ATTENTION_CLASSES,
|
||||||
KwargsForCausalLM,
|
KwargsForCausalLM,
|
||||||
LlamaAttention,
|
|
||||||
LlamaDynamicNTKScalingRotaryEmbedding,
|
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||||
LlamaLinearScalingRotaryEmbedding,
|
LlamaLinearScalingRotaryEmbedding,
|
||||||
LlamaMLP,
|
LlamaMLP,
|
||||||
@@ -329,7 +336,10 @@ class LlamaRalaDecoderLayer(nn.Module):
|
|||||||
if LlamaRalaConfig.is_layer_idx_softmax(
|
if LlamaRalaConfig.is_layer_idx_softmax(
|
||||||
config.num_hidden_layers, layer_idx, config.softmax_every
|
config.num_hidden_layers, layer_idx, config.softmax_every
|
||||||
):
|
):
|
||||||
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
|
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
|
config=config, layer_idx=layer_idx
|
||||||
|
)
|
||||||
|
# self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
|
||||||
else:
|
else:
|
||||||
self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
|
self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
@@ -594,3 +604,20 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_rala_model() -> None:
|
||||||
|
"""
|
||||||
|
Register differential attention components with the transformers library.
|
||||||
|
This function registers the differential attention configurations and model classes
|
||||||
|
with the Auto* classes from `transformers`, making them available through the
|
||||||
|
standard model loading pipeline.
|
||||||
|
"""
|
||||||
|
# Register configs
|
||||||
|
AutoConfig.register("llama-rala", LlamaRalaConfig)
|
||||||
|
|
||||||
|
# Register models
|
||||||
|
AutoModel.register(LlamaRalaConfig, LlamaRalaModel)
|
||||||
|
AutoModelForCausalLM.register(LlamaRalaConfig, LlamaRalaForCausalLM)
|
||||||
|
|
||||||
|
LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention
|
||||||
|
|||||||
@@ -7,8 +7,10 @@ from torch import nn
|
|||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
|
||||||
from axolotl.integrations.rala import LlamaRALAAttention
|
from axolotl.integrations.rala.auto.llama.modeling_rala import (
|
||||||
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
|
LlamaRALAAttention,
|
||||||
|
LlamaRalaDecoderLayer,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user