Liger Kernel integration (#1861)

* add initial plugin support w Liger kernel patches

* integrate the input args classes

* fix liger plugin and dynamic configuration class

* drop untrainable samples and refactor config plugins integration

* fix incorrect inputs and circular imports

* fix bool comparison

* fix for dropping untraibable tokens

* fix licensing so liger integration is Apache 2.0

* add jamba support

* pylint ignore
This commit is contained in:
Wing Lian
2024-08-23 12:21:51 -04:00
committed by GitHub
parent e8ff5d5738
commit 1f686c576c
12 changed files with 1010 additions and 3 deletions

View File

@@ -8,11 +8,14 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
from axolotl.utils.config.models.input.v0_4_1 import (
SUPPORTED_METRICS,
AxolotlConfigWCapabilities,
AxolotlInputConfig,
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -207,6 +210,15 @@ def normalize_cfg_datasets(cfg):
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
if cfg.plugins:
(
AxolotlConfigWCapabilities, # pylint: disable=invalid-name
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
if capabilities:
return DictDefault(
dict(