From de771fcb05554696fc43c928748b790053d70be8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 20 Jan 2025 12:18:41 -0500 Subject: [PATCH] fix convert logger and registration --- src/axolotl/integrations/rrt/__init__.py | 4 ++-- src/axolotl/integrations/rrt/cli/convert.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py index 976897ea6..35d4106b7 100644 --- a/src/axolotl/integrations/rrt/__init__.py +++ b/src/axolotl/integrations/rrt/__init__.py @@ -38,5 +38,5 @@ def register_rrt_model(): AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig) # Register models - AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel) - AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM) + AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel) + AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM) diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py index 8d90db484..3418e8d4a 100644 --- a/src/axolotl/integrations/rrt/cli/convert.py +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -1,4 +1,5 @@ import json +import logging import math import os import re @@ -15,6 +16,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig +logger = logging.getLogger(__name__) def extract_layer_number(key): """Extract layer number from parameter key."""