fix convert logger and registration
This commit is contained in:
@@ -38,5 +38,5 @@ def register_rrt_model():
|
|||||||
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
|
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
|
||||||
|
|
||||||
# Register models
|
# Register models
|
||||||
AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||||
AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -15,6 +16,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
|
|||||||
|
|
||||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
|
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def extract_layer_number(key):
|
def extract_layer_number(key):
|
||||||
"""Extract layer number from parameter key."""
|
"""Extract layer number from parameter key."""
|
||||||
|
|||||||
Reference in New Issue
Block a user