Compare commits

..

24 Commits

Author SHA1 Message Date
sunny
bbf5158e9c test 2024-11-07 11:06:28 -05:00
sunny
ec70046a2b test 2024-11-07 11:04:33 -05:00
sunny
7fed41550e test 2024-11-07 11:02:54 -05:00
sunny
da3a941bc3 test 2024-11-07 11:00:51 -05:00
sunny
ad3c179a5a test 2024-11-07 10:59:29 -05:00
sunny
15e26b14eb test 2024-11-07 10:54:48 -05:00
sunny
33bbe9b222 test 2024-11-07 10:52:52 -05:00
sunny
1fddf45958 test 2024-11-07 10:46:47 -05:00
Wing Lian
e42e319446 make sure prepared path is empty for test 2024-11-06 10:20:51 -05:00
Wing Lian
613f238e56 use kwargs to support patch release 2024-11-06 09:43:35 -05:00
Wing Lian
6b617a4fd5 also upgrade accelerate 2024-11-06 08:59:52 -05:00
Wing Lian
6ac10de9ef upgrade liger and transformers 2024-11-06 08:53:03 -05:00
Wing Lian
1b8d439441 add test case 2024-11-05 09:23:08 +07:00
Wing Lian
1ed351781a chore: lint 2024-11-05 09:23:08 +07:00
Wing Lian
c2a48c3a1e add logging 2024-11-05 09:23:08 +07:00
Wing Lian
415399b565 Update README.md
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-05 09:23:08 +07:00
Wing Lian
67c04133f2 Update src/axolotl/integrations/liger/args.py
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-05 09:23:08 +07:00
Wing Lian
4911d0952f skip duplicate code check 2024-11-05 09:23:08 +07:00
Wing Lian
1d7ab52161 update docs and example 2024-11-05 09:23:08 +07:00
Wing Lian
fcdc6fee8b upgrade liger to 0.3.1 2024-11-05 09:23:08 +07:00
Wing Lian
052a9a79b4 only run the remainder of the gpu test suite if one case passes first (#2009) [skip ci]
* only run the remainder of the gpu test suite if one case passes first

* also reduce the test matrix
2024-10-31 13:45:01 -04:00
Wing Lian
3591bcfaf9 add torch 2.5.1 for base image (#2010) 2024-10-31 13:27:49 -04:00
Wing Lian
dc1de7d81b add retries for load datasets requests failures (#2007) 2024-10-31 13:26:14 -04:00
Chirag Jain
d4dbfa02fe Add plugin manager's callback hooks to training flow (#2006)
* Add plugin manager's callback hooks to training flow

* Use .values() instead of .items()
2024-10-31 12:13:46 -04:00
21 changed files with 384 additions and 740 deletions

View File

@@ -40,7 +40,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -82,13 +82,6 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -72,13 +72,53 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
docker-e2e-tests-1st:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false
matrix:
@@ -89,18 +129,6 @@ jobs:
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -562,7 +562,8 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

View File

@@ -9,7 +9,7 @@ strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
strict: false

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.0
transformers==4.46.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
accelerate==1.1.0
datasets==3.0.1
deepspeed==0.15.3
pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.3.0
liger-kernel==0.4.0
mamba-ssm==1.2.0.post1

View File

@@ -48,6 +48,7 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -435,13 +436,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"soap",
]
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
):
return super().create_optimizer()
@@ -484,25 +479,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif self.args.alternate_optimizer == "soap":
from axolotl.utils.optimizers.soap import SOAP
optim_args = {
"lr": optimizer_kwargs.pop("lr"),
"eps": optimizer_kwargs.pop("eps"),
}
if self.cfg.optim_args:
optim_args.update(self.cfg.optim_args)
optim_args["betas"] = (
self.args.optim_soap_beta1,
self.args.optim_soap_beta2,
)
self.optimizer = SOAP( # pylint: disable=attribute-defined-outside-init
optimizer_grouped_parameters,
**optim_args,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -920,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -1172,6 +1148,12 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1198,11 +1180,17 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1248,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1638,12 +1626,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.optimizer in [
# pylint: disable=duplicate-code
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"soap",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1818,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -2027,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = []
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):

View File

@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import List
from typing import OrderedDict
class BasePlugin:
@@ -47,7 +48,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg):
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
@@ -63,7 +64,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg):
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
@@ -74,7 +75,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model):
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -86,7 +87,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model):
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
@@ -98,7 +99,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model):
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
@@ -110,7 +111,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer):
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -122,7 +123,9 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(self, cfg, trainer, optimizer):
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -135,7 +138,7 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model):
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer before training.
@@ -146,8 +149,11 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(self, cfg, trainer):
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after training.
@@ -158,8 +164,9 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model):
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
@@ -171,7 +178,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg):
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
@@ -227,7 +234,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: List[BasePlugin] = []
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
@@ -237,7 +244,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
cls._instance.plugins = collections.OrderedDict()
return cls._instance
@staticmethod
@@ -265,7 +272,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
self.plugins[plugin_name] = plugin
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -277,7 +284,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.
"""
input_args = []
for plugin in self.plugins:
for plugin in self.plugins.values():
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -293,7 +300,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -307,7 +314,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -321,7 +328,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -335,7 +342,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -349,7 +356,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -367,7 +374,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -385,7 +392,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
@@ -401,7 +408,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
@@ -416,5 +423,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)

View File

@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integration with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,9 +15,12 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
    """
    Map the deprecated `liger_swiglu` option onto `liger_glu_activation`.

    Runs before field validation, so `data` is the raw input handed to the
    model. When `liger_swiglu` is set (not None) its value is moved to
    `liger_glu_activation` and a deprecation warning is logged.

    Args:
        data: The raw, unvalidated model input.

    Returns:
        The same `data` object, with `liger_swiglu` replaced by
        `liger_glu_activation` when the deprecated key was set.

    Raises:
        ValueError: If both `liger_swiglu` and `liger_glu_activation` are set.
    """
    # A "before" validator may receive input that is not a dict; only dict
    # input can carry the deprecated key, so pass anything else through.
    if not isinstance(data, dict):
        return data

    if data.get("liger_swiglu") is None:
        return data

    if data.get("liger_glu_activation") is not None:
        raise ValueError(
            "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
        )

    LOG.warning(
        "The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
        "Please use 'liger_glu_activation' instead."
    )
    data["liger_glu_activation"] = data.pop("liger_swiglu")
    return data

View File

@@ -427,7 +427,6 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"soap",
],
]
] = OptimizerNames.ADAMW_HF.value
@@ -440,10 +439,6 @@ class HyperparametersConfig(BaseModel):
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
optim_soap_beta1: Optional[float] = None
optim_soap_beta2: Optional[float] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None

View File

@@ -2,9 +2,11 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
    """
    Build a decorator that retries a callable on transient `requests` errors.

    The wrapped callable is invoked again when it raises
    `requests.exceptions.ReadTimeout` or `requests.exceptions.ConnectionError`.
    Any other exception propagates immediately, and the exception raised by
    the final attempt propagates unchanged.

    Args:
        max_retries: Total number of attempts to make. Values below 1 are
            treated as a single attempt, so the wrapped callable is always
            invoked at least once.
        delay: Seconds to sleep after a failed attempt before trying again.

    Returns:
        A decorator that wraps a callable with the retry behaviour.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Every attempt except the last swallows the transient error and
            # waits before trying again.
            for _ in range(max_retries - 1):
                try:
                    return func(*args, **kwargs)
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
                ):
                    time.sleep(delay)
            # The final attempt runs outside the loop so its exception (if
            # any) reaches the caller with the original traceback, and so the
            # callable still runs when max_retries is 0 or negative.
            return func(*args, **kwargs)

        return wrapper

    return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2024 Nikhil Vyas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,475 +0,0 @@
# pylint: skip-file
# Copied from https://github.com/nikhilvyas/SOAP
from itertools import chain
import torch
import torch.optim as optim
# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
class SOAP(optim.Optimizer):
"""
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.003):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
Adam's betas parameters (b1, b2).
shampoo_beta (`float`, *optional*, defaults to -1):
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
eps (`float`, *optional*, defaults to 1e-08):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
precondition_frequency (`int`, *optional*, defaults to 10):
How often to update the preconditioner.
max_precond_dim (`int`, *optional*, defaults to 10000):
Maximum dimension of the preconditioner.
Set to 10000, so that we exclude most common vocab sizes while including layers.
merge_dims (`bool`, *optional*, defaults to `False`):
Whether or not to merge dimensions of the preconditioner.
precondition_1d (`bool`, *optional*, defaults to `False`):
Whether or not to precondition 1D gradients.
normalize_grads (`bool`, *optional*, defaults to `False`):
Whether or not to normalize gradients per layer.
Helps at large precondition_frequency (~100 in our experiments),
but hurts performance at small precondition_frequency (~10 in our experiments).
data_format (`str`, *optional*, defaults to `channels_first`):
Data format of the input for convolutional layers.
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias correction in Adam.
"""
def __init__(
self,
params,
lr: float = 3e-3,
betas=(0.95, 0.95),
shampoo_beta: float = -1,
eps: float = 1e-8,
weight_decay: float = 0.01,
precondition_frequency: int = 10,
max_precond_dim: int = 10000, #
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
precondition_1d: bool = False,
normalize_grads: bool = False,
data_format: str = "channels_first",
correct_bias: bool = True,
):
defaults = {
"lr": lr,
"betas": betas,
"shampoo_beta": shampoo_beta,
"eps": eps,
"weight_decay": weight_decay,
"precondition_frequency": precondition_frequency,
"max_precond_dim": max_precond_dim,
"merge_dims": merge_dims,
"precondition_1d": precondition_1d,
"normalize_grads": normalize_grads,
"correct_bias": correct_bias,
}
super().__init__(params, defaults)
self._data_format = data_format
def merge_dims(self, grad, max_precond_dim):
"""
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
"""
assert self._data_format in ["channels_first", "channels_last"]
if self._data_format == "channels_last" and grad.dim() == 4:
grad = grad.permute(0, 3, 1, 2)
shape = grad.shape
new_shape = []
curr_shape = 1
for sh in shape:
temp_shape = curr_shape * sh
if temp_shape > max_precond_dim:
if curr_shape > 1:
new_shape.append(curr_shape)
curr_shape = sh
else:
new_shape.append(sh)
curr_shape = 1
else:
curr_shape = temp_shape
if curr_shape > 1 or len(new_shape) == 0:
new_shape.append(curr_shape)
new_grad = grad.reshape(new_shape)
return new_grad
@torch.no_grad()
def step(self):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if "step" not in state:
state["step"] = 0
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
if "Q" not in state:
self.init_preconditioner(
grad,
state,
precondition_frequency=group["precondition_frequency"],
precondition_1d=group["precondition_1d"],
shampoo_beta=(
group["shampoo_beta"]
if group["shampoo_beta"] >= 0
else group["betas"][1]
),
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
)
self.update_preconditioner(
grad,
state,
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"],
)
continue # first step is skipped so that we never use the current gradients in the projection.
# Projecting gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
grad_projected = self.project(
grad,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).add_(
grad_projected.square(), alpha=(1.0 - beta2)
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
exp_avg_projected = self.project(
exp_avg,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
step_size = group["lr"]
if group["correct_bias"]:
bias_correction1 = 1.0 - beta1 ** (state["step"])
bias_correction2 = 1.0 - beta2 ** (state["step"])
step_size = step_size * (bias_correction2**0.5) / bias_correction1
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
# to the original space
norm_grad = self.project_back(
exp_avg_projected / denom,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
if group["normalize_grads"]:
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
p.add_(norm_grad, alpha=-step_size)
# From AdamW code: Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
# Update is done after the gradient step to avoid using current gradients in the projection.
self.update_preconditioner(
grad,
state,
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"],
)
return loss
def init_preconditioner(
    self,
    grad,
    state,
    precondition_frequency=10,
    shampoo_beta=0.95,
    max_precond_dim=10000,
    precondition_1d=False,
    merge_dims=False,
):
    """
    Creates the per-parameter preconditioner statistics (L and R in the paper).

    ``state["GG"]`` receives one entry per axis of ``grad`` (after optional
    dimension merging): a zero-filled square matrix when the axis is
    preconditioned, or an empty list when it is skipped. An axis is skipped
    when it is longer than ``max_precond_dim``, or when ``grad`` is 1-D and
    ``precondition_1d`` is off. ``state["Q"]`` is reset to ``None`` and the
    update frequency / Shampoo beta are recorded in ``state``.
    """

    def blank_factor(size):
        # An empty list is the marker for "this axis is not preconditioned".
        if size > max_precond_dim:
            return []
        return torch.zeros(size, size, device=grad.device)

    state["GG"] = factors = []
    if grad.dim() == 1:
        factors.append(blank_factor(grad.shape[0]) if precondition_1d else [])
    else:
        if merge_dims:
            grad = self.merge_dims(grad, max_precond_dim)
        factors.extend(blank_factor(size) for size in grad.shape)

    # The eigenbases are computed lazily on the first preconditioner update.
    state["Q"] = None
    state["precondition_frequency"] = precondition_frequency
    state["shampoo_beta"] = shampoo_beta
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
    """
    Rotates ``grad`` into the eigenbases held in ``state["Q"]``.

    Each entry of ``state["Q"]`` is applied to the leading axis, which then
    moves to the back, so after the loop the axes are in their original order.
    An empty entry leaves its axis untouched. With ``merge_dims`` the tensor
    is merged first and reshaped back to the input layout at the end.
    """
    input_shape = grad.shape
    if merge_dims:
        if grad.dim() == 4 and self._data_format == "channels_last":
            nchw_shape = grad.permute(0, 3, 1, 2).shape
        grad = self.merge_dims(grad, max_precond_dim)

    for basis in state["Q"]:
        if len(basis) == 0:
            # No eigenbasis for this axis: just cycle it to the last position.
            grad = grad.permute(*range(1, grad.dim()), 0)
            continue
        # Contract the leading axis with the rows of the basis; tensordot
        # appends the resulting axis at the end.
        grad = torch.tensordot(grad, basis, dims=[[0], [0]])

    if not merge_dims:
        return grad
    if self._data_format == "channels_last" and len(input_shape) == 4:
        return grad.reshape(nchw_shape).permute(0, 2, 3, 1)
    return grad.reshape(input_shape)
def update_preconditioner(
    self,
    grad,
    state,
    max_precond_dim=10000,
    merge_dims=False,
    precondition_1d=False,
):
    """
    Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).

    Every preconditioned axis gets its ``state["GG"]`` entry moved towards the
    gradient's outer product along that axis with weight
    ``1 - state["shampoo_beta"]``. Axes longer than ``max_precond_dim`` (and
    1-D gradients unless ``precondition_1d`` is set) are left alone.

    Afterwards the eigenbases in ``state["Q"]`` are created if they do not
    exist yet, and refreshed whenever ``state["step"]`` is a positive multiple
    of ``state["precondition_frequency"]``.
    """
    blend = 1 - state["shampoo_beta"]

    if grad.dim() == 1:
        if precondition_1d and grad.shape[0] <= max_precond_dim:
            state["GG"][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), blend)
    else:
        # The merged and unmerged paths only differ in which tensor feeds the
        # statistics, so a single loop handles both.
        stats_grad = self.merge_dims(grad, max_precond_dim) if merge_dims else grad
        ndim = stats_grad.dim()
        for idx, sh in enumerate(stats_grad.shape):
            if sh > max_precond_dim:
                continue
            # Contract across all dimensions except for idx.
            other_axes = [axis for axis in range(ndim) if axis != idx]
            outer_product = torch.tensordot(
                stats_grad,
                stats_grad,
                dims=[other_axes, other_axes],
            )
            state["GG"][idx].lerp_(outer_product, blend)

    if state["Q"] is None:
        state["Q"] = self.get_orthogonal_matrix(state["GG"])
    if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
        state["Q"] = self.get_orthogonal_matrix_QR(
            state, max_precond_dim, merge_dims
        )
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
    """
    Maps a tensor from the eigenbases in ``state["Q"]`` back to the original space.

    This walks the axes the same way as the forward projection, but contracts
    the leading axis with the columns of each basis (i.e. applies the
    transpose). Empty entries in ``state["Q"]`` leave their axis unchanged.
    With ``merge_dims`` the tensor is merged first and reshaped back to the
    input layout before returning.
    """
    input_shape = grad.shape
    if merge_dims:
        if self._data_format == "channels_last" and grad.dim() == 4:
            nchw_shape = grad.permute(0, 3, 1, 2).shape
        grad = self.merge_dims(grad, max_precond_dim)

    for basis in state["Q"]:
        if len(basis) == 0:
            # Axis was never rotated: cycle it to the back to keep axis order.
            grad = grad.permute(*range(1, grad.dim()), 0)
            continue
        grad = torch.tensordot(grad, basis, dims=[[0], [1]])

    if not merge_dims:
        return grad
    if self._data_format == "channels_last" and len(input_shape) == 4:
        return grad.reshape(nchw_shape).permute(0, 2, 3, 1)
    return grad.reshape(input_shape)
def get_orthogonal_matrix(self, mat):
    """
    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.

    Args:
        mat: list of square preconditioner matrices. An empty list entry marks
            an axis that is not preconditioned and is passed through as ``[]``.

    Returns:
        A list aligned with ``mat``. Each non-empty entry holds the
        eigenvectors as columns, ordered by descending eigenvalue, in the
        dtype and on the device of the corresponding input matrix.
    """
    final = []
    for m in mat:
        if len(m) == 0:
            final.append([])
            continue

        source = m.data
        # The decomposition runs in float32; the dtype of *this* matrix is
        # restored below (previously only the last matrix's dtype was tracked).
        work = source if source.dtype == torch.float else source.float()
        jitter = 1e-30 * torch.eye(work.shape[0], device=work.device)
        try:
            _, Q = torch.linalg.eigh(work + jitter)
        except RuntimeError:
            # torch.linalg.LinAlgError derives from RuntimeError; retry in
            # float64 when the float32 decomposition fails.
            _, Q = torch.linalg.eigh(work.to(torch.float64) + jitter)
            Q = Q.to(work.dtype)

        # eigh returns ascending eigenvalues; flip columns to get descending.
        Q = torch.flip(Q, [1])

        if source.dtype != torch.float:
            Q = Q.to(source.device).type(source.dtype)
        final.append(Q)
    return final
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
    """
    Computes the eigenbases of the preconditioner using one round of power iteration
    followed by torch.linalg.qr decomposition.

    Args:
        state: optimizer state holding ``"GG"`` (preconditioner matrices),
            ``"Q"`` (current eigenbases) and ``"exp_avg_sq"``.
        max_precond_dim: forwarded to ``self.merge_dims`` when merging.
        merge_dims: whether ``exp_avg_sq`` is merged before its axes are
            re-sorted and reshaped back afterwards.

    Returns:
        The refreshed eigenbases, aligned with ``state["GG"]``; axes without a
        preconditioner yield ``[]``. Each basis is returned in the dtype and
        on the device of its preconditioner matrix.

    Side effects:
        ``state["exp_avg_sq"]`` is replaced by a copy whose entries along each
        preconditioned axis are permuted with the same ordering that sorts the
        estimated eigenvalues in descending order.
    """
    exp_avg_sq = state["exp_avg_sq"]
    orig_shape = exp_avg_sq.shape
    # The permuted shape is only needed to undo a merge of a 4-D
    # channels_last tensor.
    channels_last_4d = (
        merge_dims
        and len(orig_shape) == 4
        and self._data_format == "channels_last"
    )
    if merge_dims:
        if channels_last_4d:
            permuted_shape = exp_avg_sq.permute(0, 3, 1, 2).shape
        exp_avg_sq = self.merge_dims(exp_avg_sq, max_precond_dim)

    final = []
    for ind, (m, o) in enumerate(zip(state["GG"], state["Q"])):
        if len(m) == 0:
            final.append([])
            continue

        precond = m.data.float()
        basis = o.data.float()

        # Estimate the eigenvalues in the current basis and sort descending;
        # exp_avg_sq is permuted identically so it stays aligned with Q.
        est_eig = torch.diag(basis.T @ precond @ basis)
        sort_idx = torch.argsort(est_eig, descending=True)
        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)

        # One power-iteration step followed by re-orthogonalisation.
        Q, _ = torch.linalg.qr(precond @ basis[:, sort_idx])

        # Restore the dtype of this matrix (previously the dtype of the last
        # non-empty matrix was applied to every basis).
        if m.data.dtype != torch.float:
            Q = Q.to(m.data.device).type(m.data.dtype)
        final.append(Q)

    if merge_dims:
        if channels_last_4d:
            exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
        else:
            exp_avg_sq = exp_avg_sq.reshape(orig_shape)

    state["exp_avg_sq"] = exp_avg_sq
    return final

76
test.yml Normal file
View File

@@ -0,0 +1,76 @@
# Small Llama run that enables the Liger plugin together with FSDP.
base_model: JackFram/llama-68m

# Liger integration and the kernels switched on for this run.
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true

strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.5
output_dir: ./outputs/out

sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true

# Weights & Biases settings are intentionally left empty.
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_backward_prefetch: BACKWARD_PRE

special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>

View File

@@ -1,7 +1,6 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
@@ -64,6 +63,51 @@ class LigerIntegrationTestCase(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
def test_llama_wo_flce2(self, temp_dir):
    """
    Trains llama-68m with the Liger plugin, using plain Liger cross entropy and
    the fused linear cross entropy kernel switched off, then checks that
    ``model.safetensors`` was written to the output directory.
    """
    model_and_kernels = {
        "base_model": "JackFram/llama-68m",
        "tokenizer_type": "LlamaTokenizer",
        "plugins": [
            "axolotl.integrations.liger.LigerPlugin",
        ],
        "liger_rope": True,
        "liger_rms_norm": True,
        "liger_swiglu": True,
        "liger_cross_entropy": True,
        "liger_fused_linear_cross_entropy": False,
    }
    data = {
        "sequence_len": 1024,
        "val_set_size": 0.1,
        "special_tokens": {
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
        },
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
    }
    training = {
        "num_epochs": 1,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch",
        "lr_scheduler": "cosine",
        "save_safetensors": True,
        "bf16": "auto",
    }
    cfg = DictDefault({**model_and_kernels, **data, **training})
    normalize_config(cfg)

    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    saved_model = Path(temp_dir) / "model.safetensors"
    assert saved_model.exists()
@with_temp_dir
def test_llama_w_flce(self, temp_dir):
cfg = DictDefault(

View File

@@ -65,44 +65,3 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_soap(self, temp_dir):
    """
    Runs an 8-bit LoRA fine-tune of SmolLM-135M with the ``soap`` optimizer and
    checks that ``adapter_model.bin`` was written to the output directory.
    """
    # pylint: disable=duplicate-code
    model_and_adapter = {
        "base_model": "HuggingFaceTB/SmolLM-135M",
        "sequence_len": 1024,
        "load_in_8bit": True,
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
    }
    data = {
        "val_set_size": 0.1,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "datasets": [
            {
                "path": "vicgalle/alpaca-gpt4",
                "type": "alpaca",
            },
        ],
    }
    training = {
        "num_epochs": 1,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "soap",
        "optim_soap_beta1": 0.95,
        "optim_soap_beta2": 0.95,
        "lr_scheduler": "cosine",
    }
    cfg = DictDefault({**model_and_adapter, **data, **training})
    normalize_config(cfg)

    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    adapter_file = Path(temp_dir) / "adapter_model.bin"
    assert adapter_file.exists()

View File

@@ -0,0 +1,80 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
    """
    Provides a small config (base model, learning rate, one alpaca dataset and
    batch settings) for validation tests to extend.
    """
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }
    return DictDefault(settings)
class BaseValidation:
    """
    Common base for validation tests that exposes pytest's log capture
    fixture on the instance as ``self._caplog``.
    """

    # Populated for every test by the autouse fixture below.
    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Stores the ``caplog`` fixture on the test instance."""
        self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
    """
    Test the validation module for liger
    """

    # The tests request ``minimal_base_cfg``, the name under which the config
    # fixture in this module is registered, and use the ``liger_glu_activation``
    # key, matching the spelling in the expected error message below.

    def test_deprecated_swiglu(self, minimal_base_cfg):
        """
        Setting only ``liger_swiglu`` should log a deprecation warning, clear
        ``liger_swiglu`` and carry its value over to ``liger_glu_activation``.
        """
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
            }
            | minimal_base_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            updated_cfg = validate_config(test_cfg)
            assert (
                "The 'liger_swiglu' argument is deprecated"
                in self._caplog.records[0].message
            )
            assert updated_cfg.liger_swiglu is None
            assert updated_cfg.liger_glu_activation is False

    def test_conflict_swiglu_ligergluactivation(self, minimal_base_cfg):
        """
        Setting both ``liger_swiglu`` and ``liger_glu_activation`` should be
        rejected with a ``ValueError``.
        """
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
                "liger_glu_activation": True,
            }
            | minimal_base_cfg
        )

        with pytest.raises(
            ValueError,
            match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
        ):
            validate_config(test_cfg)

View File

@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
# make sure prepared_path is empty
shutil.rmtree(prepared_path, ignore_errors=True)
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",