fix convert logger and registration
This commit is contained in:
@@ -38,5 +38,5 @@ def register_rrt_model():
|
||||
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
|
||||
|
||||
# Register models
|
||||
AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||
AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
||||
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||
AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
@@ -15,6 +16,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_layer_number(key):
|
||||
"""Extract layer number from parameter key."""
|
||||
|
||||
Reference in New Issue
Block a user