refactor a bit
This commit is contained in:
@@ -264,8 +264,16 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
|
||||
|
||||
|
||||
def convert_llama_to_rrt(
|
||||
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
|
||||
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None
|
||||
):
|
||||
if not device:
|
||||
if torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
modules_to_recurse = [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
@@ -329,17 +337,10 @@ def convert_llama_to_rrt(
|
||||
|
||||
if __name__ == "__main__":
|
||||
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||
if torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
convert_llama_to_rrt(
|
||||
"meta-llama/Llama-3.2-1B",
|
||||
"/tmp/rrt_model",
|
||||
recurse_layers=4,
|
||||
rank=256,
|
||||
alpha=512,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from transformers import LlamaConfig
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||
"""
|
||||
Configuration for Relaxed Recursive Llama.
|
||||
"""
|
||||
|
||||
model_type = "llama-rrt"
|
||||
recurse_layers: int = 4
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool = True
|
||||
@@ -12,20 +12,10 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager
|
||||
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
|
||||
|
||||
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
||||
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||
"""
|
||||
Configuration for Relaxed Recursive Llama.
|
||||
"""
|
||||
|
||||
model_type = "llama-rrt"
|
||||
recurse_layers: int = 4
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool = True
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaMLP(nn.Module):
|
||||
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
||||
|
||||
Reference in New Issue
Block a user