Liger Kernel integration (#1861)
* add initial plugin support w Liger kernel patches * integrate the input args classes * fix liger plugin and dynamic configuration class * drop untrainable samples and refactor config plugins integration * fix incorrect inputs and circular imports * fix bool comparison * fix for dropping untraibable tokens * fix licensing so liger integration is Apache 2.0 * add jamba support * pylint ignore
This commit is contained in:
@@ -308,10 +308,17 @@ def load_model(
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
|
||||
base_model = cfg.base_model
|
||||
model_type = cfg.type_of_model
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
# load any patches from plugins
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(cfg)
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
|
||||
|
||||
Reference in New Issue
Block a user