fix convert logger and registration

This commit is contained in:
Wing Lian
2025-01-20 12:18:41 -05:00
parent f32d429db5
commit de771fcb05
2 changed files with 4 additions and 2 deletions

View File

@@ -38,5 +38,5 @@ def register_rrt_model():
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)

View File

@@ -1,4 +1,5 @@
import json
import logging
import math
import os
import re
@@ -15,6 +16,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__)
def extract_layer_number(key):
"""Extract layer number from parameter key."""