refactor a bit

This commit is contained in:
Wing Lian
2025-01-21 10:14:16 -05:00
parent b582d340b0
commit 08a4e8a7fb
3 changed files with 23 additions and 19 deletions

View File

@@ -264,8 +264,16 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
def convert_llama_to_rrt(
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None
):
if not device:
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
@@ -329,17 +337,10 @@ def convert_llama_to_rrt(
if __name__ == "__main__":
# meta-llama/Llama-3.2-1B has 16 hidden layers
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
convert_llama_to_rrt(
"meta-llama/Llama-3.2-1B",
"/tmp/rrt_model",
recurse_layers=4,
rank=256,
alpha=512,
device=device,
)

View File

@@ -0,0 +1,13 @@
from transformers import LlamaConfig
class RelaxedRecursiveLlamaConfig(LlamaConfig):
    """
    Configuration for Relaxed Recursive Llama.

    Extends ``LlamaConfig`` with the extra settings declared below.
    """

    # String identifier for this configuration class.
    # NOTE(review): presumably the key used for model-type lookup/registration
    # in the config machinery -- confirm against the parent class.
    model_type = "llama-rrt"
    # Class-level default of 4.
    # NOTE(review): presumably relates to how many layers are shared/recursed;
    # the exact meaning is not visible here -- confirm against the modeling code.
    recurse_layers: int = 4
    # Annotation only: a bare annotation creates no class attribute, so there is
    # no default for `rank`; the value has to be set elsewhere.
    # NOTE(review): presumably the low-rank adapter rank -- confirm.
    rank: int
    # Annotation only, no default (same as `rank`).
    # NOTE(review): presumably the adapter scaling factor -- confirm.
    alpha: int
    # Class-level default of True.
    # NOTE(review): presumably selects the DoRA-style linear layers -- confirm.
    use_dora: bool = True

View File

@@ -12,20 +12,10 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__)
class RelaxedRecursiveLlamaConfig(LlamaConfig):
    """
    Configuration for Relaxed Recursive Llama.

    Extends ``LlamaConfig`` with the extra settings declared below.
    """

    # String identifier for this configuration class.
    # NOTE(review): presumably the key used for model-type lookup/registration
    # in the config machinery -- confirm against the parent class.
    model_type = "llama-rrt"
    # Class-level default of 4.
    # NOTE(review): presumably relates to how many layers are shared/recursed;
    # the exact meaning is not visible here -- confirm against the modeling code.
    recurse_layers: int = 4
    # Annotation only: a bare annotation creates no class attribute, so there is
    # no default for `rank`; the value has to be set elsewhere.
    # NOTE(review): presumably the low-rank adapter rank -- confirm.
    rank: int
    # Annotation only, no default (same as `rank`).
    # NOTE(review): presumably the adapter scaling factor -- confirm.
    alpha: int
    # Class-level default of True.
    # NOTE(review): presumably selects the DoRA-style linear layers -- confirm.
    use_dora: bool = True
class RelaxedRecursiveLlamaMLP(nn.Module):
def __init__(self, config: RelaxedRecursiveLlamaConfig):