Liger Kernel integration (#1861)

* add initial plugin support w Liger kernel patches

* integrate the input args classes

* fix liger plugin and dynamic configuration class

* drop untrainable samples and refactor config plugins integration

* fix incorrect inputs and circular imports

* fix bool comparison

* fix for dropping untraibable tokens

* fix licensing so liger integration is Apache 2.0

* add jamba support

* pylint ignore
This commit is contained in:
Wing Lian
2024-08-23 12:21:51 -04:00
committed by GitHub
parent e8ff5d5738
commit 1f686c576c
12 changed files with 1010 additions and 3 deletions

View File

@@ -308,10 +308,17 @@ def load_model(
"""
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
model_type = cfg.type_of_model
model_config = load_model_config(cfg)
# load any patches from plugins
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(cfg)
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit