refactor a bit
This commit is contained in:
@@ -264,8 +264,16 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
|
|||||||
|
|
||||||
|
|
||||||
def convert_llama_to_rrt(
|
def convert_llama_to_rrt(
|
||||||
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
|
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None
|
||||||
):
|
):
|
||||||
|
if not device:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
modules_to_recurse = [
|
modules_to_recurse = [
|
||||||
"self_attn.q_proj",
|
"self_attn.q_proj",
|
||||||
"self_attn.k_proj",
|
"self_attn.k_proj",
|
||||||
@@ -329,17 +337,10 @@ def convert_llama_to_rrt(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
device = "mps"
|
|
||||||
elif torch.cuda.is_available():
|
|
||||||
device = "cuda"
|
|
||||||
else:
|
|
||||||
device = "cpu"
|
|
||||||
convert_llama_to_rrt(
|
convert_llama_to_rrt(
|
||||||
"meta-llama/Llama-3.2-1B",
|
"meta-llama/Llama-3.2-1B",
|
||||||
"/tmp/rrt_model",
|
"/tmp/rrt_model",
|
||||||
recurse_layers=4,
|
recurse_layers=4,
|
||||||
rank=256,
|
rank=256,
|
||||||
alpha=512,
|
alpha=512,
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||||
|
"""
|
||||||
|
Configuration for Relaxed Recursive Llama.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "llama-rrt"
|
||||||
|
recurse_layers: int = 4
|
||||||
|
rank: int
|
||||||
|
alpha: int
|
||||||
|
use_dora: bool = True
|
||||||
@@ -12,20 +12,10 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager
|
|||||||
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
|
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
|
||||||
|
|
||||||
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
||||||
|
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
|
||||||
"""
|
|
||||||
Configuration for Relaxed Recursive Llama.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_type = "llama-rrt"
|
|
||||||
recurse_layers: int = 4
|
|
||||||
rank: int
|
|
||||||
alpha: int
|
|
||||||
use_dora: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class RelaxedRecursiveLlamaMLP(nn.Module):
|
class RelaxedRecursiveLlamaMLP(nn.Module):
|
||||||
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
||||||
|
|||||||
Reference in New Issue
Block a user