diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py index 976897ea6..35d4106b7 100644 --- a/src/axolotl/integrations/rrt/__init__.py +++ b/src/axolotl/integrations/rrt/__init__.py @@ -38,5 +38,5 @@ def register_rrt_model(): AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig) # Register models - AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel) - AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM) + AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel) + AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM) diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py index 8d90db484..3418e8d4a 100644 --- a/src/axolotl/integrations/rrt/cli/convert.py +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -1,4 +1,5 @@ import json +import logging import math import os import re @@ -15,6 +16,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig +logger = logging.getLogger(__name__) def extract_layer_number(key): """Extract layer number from parameter key."""