Compare commits: soap-optim...shampoo-lo (7 commits)

| Author | SHA1 | Date |
|---|---|---|
| | f1b4030cdd | |
| | 035e9f9dd7 | |
| | 02ce520b7e | |
| | 052a9a79b4 | |
| | 3591bcfaf9 | |
| | dc1de7d81b | |
| | d4dbfa02fe | |
.github/workflows/base.yml: 2 changes (vendored)
@@ -40,7 +40,7 @@ jobs:
           cuda_version: 12.4.1
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.5.0
+          pytorch: 2.5.1
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
.github/workflows/tests-nightly.yml: 7 changes (vendored)
@@ -82,13 +82,6 @@ jobs:
             num_gpus: 1
             axolotl_extras: mamba-ssm
             nightly_build: "true"
-          - cuda: 121
-            cuda_version: 12.1.1
-            python_version: "3.11"
-            pytorch: 2.3.1
-            num_gpus: 1
-            axolotl_extras: mamba-ssm
-            nightly_build: "true"
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
.github/workflows/tests.yml: 54 changes (vendored)
@@ -72,13 +72,53 @@ jobs:
         run: |
           find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

-  docker-e2e-tests:
+  docker-e2e-tests-1st:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
     timeout-minutes: 90
     needs: [pre-commit, pytest]

+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.1
+            num_gpus: 1
+            axolotl_extras:
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install Modal
+        run: |
+          python -m pip install --upgrade pip
+          pip install modal==0.63.64 jinja2
+      - name: Update env vars
+        run: |
+          echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
+          echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
+          echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
+          echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
+          echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
+          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+      - name: Run tests job on Modal
+        run: |
+          modal run cicd.tests
+
+  docker-e2e-tests:
+    if: github.repository_owner == 'axolotl-ai-cloud'
+    # this job needs to be run on self-hosted GPU runners...
+    runs-on: [self-hosted, modal]
+    timeout-minutes: 90
+    needs: [pre-commit, pytest, docker-e2e-tests-1st]
+
     strategy:
       fail-fast: false
       matrix:
@@ -89,18 +129,6 @@ jobs:
             pytorch: 2.3.1
             num_gpus: 1
             axolotl_extras: mamba-ssm
-          - cuda: 121
-            cuda_version: 12.1.1
-            python_version: "3.11"
-            pytorch: 2.3.1
-            num_gpus: 1
-            axolotl_extras: mamba-ssm
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.4.1
-            num_gpus: 1
-            axolotl_extras:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
@@ -562,7 +562,8 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
+liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
 ```

@@ -35,3 +35,7 @@ RUN git lfs install --skip-repo && \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10
+
+RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
+        pip3 install flash-attn==2.6.3; \
+    fi
@@ -9,7 +9,7 @@ strict: false
 plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 chat_template: deepseek_v2
@@ -4,7 +4,7 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 strict: false
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.46.1
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
-accelerate==1.0.1
+accelerate==1.1.0
 datasets==3.0.1
 deepspeed==0.15.3
 pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
-liger-kernel==0.3.0
+liger-kernel==0.4.0

 mamba-ssm==1.2.0.post1
@@ -48,6 +48,7 @@ from trl import (
 )
 from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length

+from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils import is_comet_available, is_mlflow_available
@@ -895,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)

-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        return super()._save_checkpoint(model, trial, **kwargs)


 class AxolotlMambaTrainer(AxolotlTrainer):
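The `_save_checkpoint` override above stops naming the `metrics` parameter and forwards `**kwargs` instead. A plausible reading, not stated in the diff, is that this keeps the override working as `transformers` (pinned to 4.46.1 elsewhere in this compare) changes the private method's signature. A minimal sketch of the same forwarding pattern, using a stand-in base class rather than the real `Trainer`:

```
import os


class _FakeTrainer:
    """Stand-in for transformers.Trainer; the real private API may differ."""

    def _get_output_dir(self, trial=None):
        return "outputs"

    def _save_checkpoint(self, model, trial, **kwargs):
        print(f"saving {model!r} with extras {kwargs}")


class _PatchedTrainer(_FakeTrainer):
    def _save_checkpoint(self, model, trial, **kwargs):
        # create the checkpoint dir up front, then pass through whatever
        # keyword arguments the current base-class version happens to send
        output_dir = os.path.join(self._get_output_dir(trial=trial), "checkpoint-1")
        os.makedirs(output_dir, exist_ok=True)
        return super()._save_checkpoint(model, trial, **kwargs)


_PatchedTrainer()._save_checkpoint("dummy-model", trial=None, metrics={"loss": 0.1})
```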
@@ -1147,6 +1148,12 @@ class TrainerBuilderBase(abc.ABC):

     def get_callbacks(self) -> List[TrainerCallback]:
         callbacks = []
+
+        plugin_manager = PluginManager.get_instance()
+        callbacks.extend(
+            plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
+        )
+
         if self.cfg.use_wandb:
             callbacks.append(
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1173,11 +1180,17 @@

         return callbacks

-    @abstractmethod
     def get_post_trainer_create_callbacks(self, trainer):
         """
         Callbacks added after the trainer is created, usually b/c these need access to the trainer
         """
+        callbacks = []
+
+        plugin_manager = PluginManager.get_instance()
+        callbacks.extend(
+            plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
+        )
+        return callbacks

     def hook_pre_create_training_args(self, training_arguments_kwargs):
         # TODO
@@ -1223,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
                 trainer, self.tokenizer, "wandb"
@@ -1791,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
         return callbacks

     def build_training_arguments(self, total_num_steps):
@@ -2000,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
         """

     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
         return callbacks

     def build(self, total_num_steps):
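Taken together, these hunks route plugin callbacks through the builder hierarchy: the base `get_callbacks` and `get_post_trainer_create_callbacks` now seed their lists from the `PluginManager`, and the concrete builders call `super()` instead of starting from an empty list. A condensed sketch of that layering, using stand-in classes rather than the real axolotl builders:

```
class FakePluginManager:
    """Stand-in for axolotl.integrations.base.PluginManager."""

    @staticmethod
    def get_instance():
        return FakePluginManager()

    def add_callbacks_pre_trainer(self, cfg, model):
        return ["plugin-pre-trainer-callback"]

    def add_callbacks_post_trainer(self, cfg, trainer):
        return ["plugin-post-trainer-callback"]


class TrainerBuilderBase:
    def get_callbacks(self):
        # plugin callbacks are collected before any builder-specific ones
        return FakePluginManager.get_instance().add_callbacks_pre_trainer(None, None)

    def get_post_trainer_create_callbacks(self, trainer):
        return FakePluginManager.get_instance().add_callbacks_post_trainer(None, trainer)


class HFCausalTrainerBuilder(TrainerBuilderBase):
    def get_post_trainer_create_callbacks(self, trainer):
        # previously started from []; now inherits plugin callbacks via super()
        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
        callbacks.append("wandb-prediction-logging-callback")
        return callbacks


print(HFCausalTrainerBuilder().get_post_trainer_create_callbacks(trainer=None))
# ['plugin-post-trainer-callback', 'wandb-prediction-logging-callback']
```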
@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process

 To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
 """
+import collections
 import importlib
 import logging
-from typing import List
+from typing import OrderedDict


 class BasePlugin:
@@ -47,7 +48,7 @@ class BasePlugin:
         Initializes the BasePlugin.
         """

-    def register(self, cfg):
+    def register(self, cfg):  # pylint: disable=unused-argument
         """
         Registers the plugin with the given configuration.

@@ -63,7 +64,7 @@ class BasePlugin:
         Returns a pydantic model for the plugin's input arguments.
         """

-    def pre_model_load(self, cfg):
+    def pre_model_load(self, cfg):  # pylint: disable=unused-argument
         """
         Performs actions before the model is loaded.

@@ -74,7 +75,7 @@ class BasePlugin:
         None
         """

-    def post_model_load(self, cfg, model):
+    def post_model_load(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions after the model is loaded.

@@ -86,7 +87,7 @@ class BasePlugin:
         None
         """

-    def pre_lora_load(self, cfg, model):
+    def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions before LoRA weights are loaded.

@@ -98,7 +99,7 @@ class BasePlugin:
         None
         """

-    def post_lora_load(self, cfg, model):
+    def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions after LoRA weights are loaded.

@@ -110,7 +111,7 @@ class BasePlugin:
         None
         """

-    def create_optimizer(self, cfg, trainer):
+    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
         """
         Creates and returns an optimizer for training.

@@ -122,7 +123,9 @@ class BasePlugin:
         object: The created optimizer.
         """

-    def create_lr_scheduler(self, cfg, trainer, optimizer):
+    def create_lr_scheduler(
+        self, cfg, trainer, optimizer
+    ):  # pylint: disable=unused-argument
         """
         Creates and returns a learning rate scheduler.

@@ -135,7 +138,7 @@ class BasePlugin:
         object: The created learning rate scheduler.
         """

-    def add_callbacks_pre_trainer(self, cfg, model):
+    def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
         """
         Adds callbacks to the trainer before training.

@@ -146,8 +149,11 @@ class BasePlugin:
         Returns:
         List[callable]: A list of callback functions to be added to the TrainingArgs
         """
+        return []

-    def add_callbacks_post_trainer(self, cfg, trainer):
+    def add_callbacks_post_trainer(
+        self, cfg, trainer
+    ):  # pylint: disable=unused-argument
         """
         Adds callbacks to the trainer after training.

@@ -158,8 +164,9 @@ class BasePlugin:
         Returns:
         List[callable]: A list of callback functions to be added to the TrainingArgs
         """
+        return []

-    def post_train(self, cfg, model):
+    def post_train(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions after training is complete.

@@ -171,7 +178,7 @@ class BasePlugin:
         None
         """

-    def post_train_unload(self, cfg):
+    def post_train_unload(self, cfg):  # pylint: disable=unused-argument
         """
         Performs actions after training is complete and the model is unloaded.

@@ -227,7 +234,7 @@ class PluginManager:
     pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
     """

-    plugins: List[BasePlugin] = []
+    plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()

     _instance = None

@@ -237,7 +244,7 @@ class PluginManager:
         """
         if cls._instance is None:
             cls._instance = super(PluginManager, cls).__new__(cls)
-            cls._instance.plugins: List[BasePlugin] = []
+            cls._instance.plugins = collections.OrderedDict()
         return cls._instance

     @staticmethod
@@ -265,7 +272,7 @@ class PluginManager:
         """
         try:
             plugin = load_plugin(plugin_name)
-            self.plugins.append(plugin)
+            self.plugins[plugin_name] = plugin
         except ImportError:
             logging.error(f"Failed to load plugin: {plugin_name}")

@@ -277,7 +284,7 @@ class PluginManager:
         list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
         """
         input_args = []
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             input_args_from_plugin = plugin.get_input_args()
             if input_args_from_plugin is not None:
                 input_args.append(input_args_from_plugin)
@@ -293,7 +300,7 @@ class PluginManager:
         Returns:
         None
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             plugin.pre_model_load(cfg)

     def post_model_load(self, cfg, model):
@@ -307,7 +314,7 @@ class PluginManager:
         Returns:
         None
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             plugin.post_model_load(cfg, model)

     def pre_lora_load(self, cfg, model):
@@ -321,7 +328,7 @@ class PluginManager:
         Returns:
         None
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             plugin.pre_lora_load(cfg, model)

     def post_lora_load(self, cfg, model):
@@ -335,7 +342,7 @@ class PluginManager:
         Returns:
         None
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             plugin.post_lora_load(cfg, model)

     def create_optimizer(self, cfg, trainer):
@@ -349,7 +356,7 @@ class PluginManager:
         Returns:
         object: The created optimizer, or None if none was found.
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             optimizer = plugin.create_optimizer(cfg, trainer)
             if optimizer is not None:
                 return optimizer
@@ -367,7 +374,7 @@ class PluginManager:
         Returns:
         object: The created learning rate scheduler, or None if none was found.
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
             if scheduler is not None:
                 return scheduler
@@ -385,7 +392,7 @@ class PluginManager:
         List[callable]: A list of callback functions to be added to the TrainingArgs.
         """
         callbacks = []
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
         return callbacks

@@ -401,7 +408,7 @@ class PluginManager:
         List[callable]: A list of callback functions to be added to the TrainingArgs.
         """
         callbacks = []
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
         return callbacks

@@ -416,5 +423,5 @@ class PluginManager:
         Returns:
         None
         """
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
             plugin.post_train_unload(cfg)
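With `plugins` now an `OrderedDict` keyed by plugin name, registering the same plugin twice replaces the entry instead of appending a duplicate, and all the `for plugin in self.plugins.values()` loops still see plugins in registration order. A stripped-down sketch of that behaviour (not the real `PluginManager`):

```
import collections


class MiniPluginManager:
    def __init__(self):
        self.plugins = collections.OrderedDict()

    def register(self, name, plugin):
        # keyed by name: re-registering replaces instead of duplicating
        self.plugins[name] = plugin

    def pre_model_load(self, cfg):
        for plugin in self.plugins.values():  # insertion order preserved
            plugin.pre_model_load(cfg)


class NoopPlugin:
    def pre_model_load(self, cfg):
        print("pre_model_load called with", cfg)


manager = MiniPluginManager()
manager.register("axolotl.integrations.liger.LigerPlugin", NoopPlugin())
manager.register("axolotl.integrations.liger.LigerPlugin", NoopPlugin())  # no duplicate
assert len(manager.plugins) == 1
manager.pre_model_load({"model_config_type": "llama"})
```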
@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
 Liger Kernel is the collection of Triton-native kernels for LLM Training.
 It is designed to be performant, correct, and light-weight.
 """
+import inspect
 import logging
 import sys
-from functools import partial

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
-from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

 from axolotl.integrations.base import BasePlugin

+from ...utils.distributed import zero_only
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401

+LOG = logging.getLogger("axolotl.integrations.liger")
+

 class LigerPlugin(BasePlugin):
     """
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
         return "axolotl.integrations.liger.LigerArgs"

     def pre_model_load(self, cfg):
-        if cfg.model_config_type == "llama":
-            from liger_kernel.transformers.model.llama import (
-                lce_forward as llama_lce_forward,
-            )
-            from transformers.models.llama import modeling_llama
-
-            if cfg.liger_rope:
-                modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_llama.LlamaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_llama.LlamaMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
-            elif cfg.liger_fused_linear_cross_entropy:
-                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
-
-        elif cfg.model_config_type == "mistral":
-            from liger_kernel.transformers.model.mistral import (
-                lce_forward as mistral_lce_forward,
-            )
-            from transformers.models.mistral import modeling_mistral
-
-            if cfg.liger_rope:
-                modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_mistral.MistralRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_mistral.MistralMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
-
-        elif cfg.model_config_type == "gemma":
-            from liger_kernel.transformers.model.gemma import (
-                lce_forward as gemma_lce_forward,
-            )
-            from transformers.models.gemma import modeling_gemma
-
-            if cfg.liger_rope:
-                modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma.GemmaRMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma.GemmaMLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
+            liger_fn_sig = inspect.signature(apply_liger_fn)
+            kwargs = {}
+            if "rope" in liger_fn_sig.parameters:
+                kwargs["rope"] = cfg.liger_rope
+            if "cross_entropy" in liger_fn_sig.parameters:
+                kwargs["cross_entropy"] = cfg.liger_cross_entropy
+            if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
+                kwargs[
+                    "fused_linear_cross_entropy"
+                ] = cfg.liger_fused_linear_cross_entropy
+            if "rms_norm" in liger_fn_sig.parameters:
+                kwargs["rms_norm"] = cfg.liger_rms_norm
+            if "layer_norm" in liger_fn_sig.parameters:
+                kwargs["layer_norm"] = cfg.liger_layer_norm
+            if "geglu" in liger_fn_sig.parameters:
+                kwargs["geglu"] = cfg.liger_glu_activation
+            elif "swiglu" in liger_fn_sig.parameters:
+                kwargs["swiglu"] = cfg.liger_glu_activation
+            with zero_only():
+                LOG.info(
+                    f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
+                )
+            apply_liger_fn(**kwargs)

         elif cfg.model_config_type == "jamba":
             from transformers.models.jamba import modeling_jamba
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
                 modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
             if cfg.liger_rms_norm:
                 modeling_jamba.JambaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_jamba.JambaMLP = LigerSwiGLUMLP
             if cfg.liger_cross_entropy:
                 modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward

-        elif cfg.model_config_type == "qwen2":
-            from liger_kernel.transformers.model.qwen2 import (
-                lce_forward as qwen2_lce_forward,
-            )
-            from transformers.models.qwen2 import modeling_qwen2
-
-            if cfg.liger_rope:
-                modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
-
         elif cfg.model_config_type == "deepseek_v2":
             from accelerate import init_empty_weights
             from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
                 logging.warning("Fused liger_rope is not supported for DeepseekV2.")
             if cfg.liger_rms_norm:
                 modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
             if cfg.liger_cross_entropy:
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
-
-        elif cfg.model_config_type == "gemma2":
-            from transformers.models.gemma2 import modeling_gemma2
-
-            if cfg.liger_rope:
-                modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma2.Gemma2RMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                logging.warning(
-                    "Fused linear cross entropy is not supported for Gemma 2."
-                )
-
-        elif cfg.model_config_type == "phi3":
-            from liger_kernel.transformers.model.phi3 import (
-                lce_forward as phi3_lce_forward,
-            )
-            from transformers.models.phi3 import modeling_phi3
-
-            if cfg.liger_rope:
-                modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_phi3.Phi3RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_phi3.Phi3MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
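Instead of hard-coding per-architecture monkey patches, `pre_model_load` now looks the model type up in liger-kernel's `MODEL_TYPE_TO_APPLY_LIGER_FN` registry and builds keyword arguments by inspecting the target function's signature, so only switches the function actually accepts are passed. The same filtering pattern in isolation; `apply_liger_kernel_to_fake_model` below is a stand-in, not the real liger-kernel API:

```
import inspect


def apply_liger_kernel_to_fake_model(
    rope=True, rms_norm=True, swiglu=True, cross_entropy=False,
    fused_linear_cross_entropy=True,
):
    # stand-in for a liger-kernel apply_liger_kernel_to_<model>() function
    print("applying liger with:", locals())


cfg = {
    "liger_rope": True,
    "liger_rms_norm": True,
    "liger_glu_activation": True,
    "liger_cross_entropy": False,
    "liger_fused_linear_cross_entropy": True,
    "liger_layer_norm": None,
}

sig = inspect.signature(apply_liger_kernel_to_fake_model)
kwargs = {}
if "rope" in sig.parameters:
    kwargs["rope"] = cfg["liger_rope"]
if "rms_norm" in sig.parameters:
    kwargs["rms_norm"] = cfg["liger_rms_norm"]
if "layer_norm" in sig.parameters:  # skipped here: not in this signature
    kwargs["layer_norm"] = cfg["liger_layer_norm"]
if "geglu" in sig.parameters:  # one config flag covers geglu or swiglu MLPs
    kwargs["geglu"] = cfg["liger_glu_activation"]
elif "swiglu" in sig.parameters:
    kwargs["swiglu"] = cfg["liger_glu_activation"]
if "cross_entropy" in sig.parameters:
    kwargs["cross_entropy"] = cfg["liger_cross_entropy"]
if "fused_linear_cross_entropy" in sig.parameters:
    kwargs["fused_linear_cross_entropy"] = cfg["liger_fused_linear_cross_entropy"]

apply_liger_kernel_to_fake_model(**kwargs)
```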
@@ -15,9 +15,12 @@
 """
 Module for handling LIGER input arguments.
 """
+import logging
 from typing import Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator

+LOG = logging.getLogger("axolotl.integrations.liger.args")
+

 class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):

     liger_rope: Optional[bool] = None
     liger_rms_norm: Optional[bool] = None
+    liger_layer_norm: Optional[bool] = None
     liger_swiglu: Optional[bool] = None
+    liger_glu_activation: Optional[bool] = None
     liger_cross_entropy: Optional[bool] = None
     liger_fused_linear_cross_entropy: Optional[bool] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deprecated_swiglu(cls, data):
+        if data.get("liger_swiglu") is not None:
+            if data.get("liger_glu_activation") is not None:
+                raise ValueError(
+                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
+                )
+
+            LOG.warning(
+                "The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
+                "Please use 'liger_glu_activation' instead."
+            )
+            data["liger_glu_activation"] = data.pop("liger_swiglu")
+        return data
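The new `model_validator` turns `liger_swiglu` into a deprecated alias: set on its own it is moved into `liger_glu_activation` with a warning, and setting both raises. A self-contained sketch of the same shim on a standalone pydantic model (not the real `LigerArgs`):

```
import logging
from typing import Optional

from pydantic import BaseModel, model_validator

LOG = logging.getLogger("example")


class Args(BaseModel):
    liger_swiglu: Optional[bool] = None
    liger_glu_activation: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_deprecated_swiglu(cls, data):
        if data.get("liger_swiglu") is not None:
            if data.get("liger_glu_activation") is not None:
                raise ValueError(
                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
                )
            LOG.warning("'liger_swiglu' is deprecated; use 'liger_glu_activation'.")
            data["liger_glu_activation"] = data.pop("liger_swiglu")
        return data


print(Args(liger_swiglu=True))  # liger_swiglu=None liger_glu_activation=True
```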
@@ -2,9 +2,11 @@

 import functools
 import logging
+import time
 from pathlib import Path
 from typing import List, Optional, Tuple, Union

+import requests
 from datasets import (
     Dataset,
     DatasetDict,
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
 LOG = logging.getLogger("axolotl")


+def retry_on_request_exceptions(max_retries=3, delay=1):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except (
+                    requests.exceptions.ReadTimeout,
+                    requests.exceptions.ConnectionError,
+                ) as exc:
+                    if attempt < max_retries - 1:
+                        time.sleep(delay)
+                    else:
+                        raise exc
+
+        return wrapper
+
+    return decorator
+
+
+@retry_on_request_exceptions(max_retries=3, delay=5)
 def prepare_dataset(cfg, tokenizer, processor=None):
     prompters = []
     if not cfg.pretraining_dataset:
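`prepare_dataset` is now wrapped with `retry_on_request_exceptions`, so transient Hub timeouts and connection resets are retried after a delay instead of failing the run immediately. A small usage sketch of the same decorator on an arbitrary flaky function; the failure below is simulated, not a real network call:

```
import requests

from axolotl.utils.data import retry_on_request_exceptions  # added by this diff

calls = {"count": 0}


@retry_on_request_exceptions(max_retries=3, delay=0.1)
def flaky_download():
    # fails twice with a retryable exception, then succeeds on the third try
    calls["count"] += 1
    if calls["count"] < 3:
        raise requests.exceptions.ConnectionError("simulated transient failure")
    return "ok"


print(flaky_download())  # "ok" after two retries
```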
src/axolotl/utils/optimizers/__init__.py: new file (0 lines)
src/axolotl/utils/optimizers/shampoo.py: new file (250 lines)
@@ -0,0 +1,250 @@ (new file; contents added in full)
from typing import Optional

import torch
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8


class _ShampooBase(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size,
        quantization_bits,
        optimizer_state_class,
    ):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps value: {eps}")
        if update_freq < 1:
            raise ValueError(f"Invalid update_freq value: {update_freq}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eps=eps,
            update_freq=update_freq,
        )
        super().__init__(params, defaults)
        self.block_size = block_size
        self.quantization_bits = quantization_bits
        self.optimizer_state_class = optimizer_state_class

    def step(self, closure: Optional[callable] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["momentum_buffer"] = self._new_buffer(grad, True)
                    state["preconds"] = []
                    state["inv_preconds"] = []
                    for dim in grad.size():
                        state["preconds"].append(
                            self.optimizer_state_class.zeros(
                                (dim, dim),
                                signed=False,
                                block_size=self.block_size,
                                device=grad.device,
                            )
                        )
                        state["inv_preconds"].append(
                            torch.zeros((dim, dim), device=grad.device)
                        )

                state["step"] += 1
                beta = group["momentum"]
                weight_decay = group["weight_decay"]
                lr = group["lr"]
                eps = group["eps"]
                update_freq = group["update_freq"]

                # Apply momentum
                if beta > 0:
                    state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
                    grad = state["momentum_buffer"]

                # Apply weight decay
                if weight_decay > 0:
                    grad = grad.add(p.data, alpha=weight_decay)

                # Preconditioning
                order = grad.ndimension()
                original_size = grad.size()
                for dim_id, dim in enumerate(grad.size()):
                    precond = state["preconds"][dim_id]
                    inv_precond = state["inv_preconds"][dim_id]

                    # Reshape grad
                    grad = grad.transpose(0, dim_id).contiguous()
                    transposed_size = grad.size()
                    grad = grad.view(dim, -1)

                    grad_t = grad.t()

                    # Update preconditioner
                    precond_fp32 = precond.dequantize()
                    precond_update = grad @ grad_t
                    precond_fp32.add_(precond_update)

                    # Quantize preconditioner back
                    precond.copy_(precond_fp32)

                    # Update inverse preconditioner
                    if state["step"] % update_freq == 0:
                        inv_precond.copy_(
                            self._compute_inv_precond(precond_fp32, eps, order)
                        )

                    # Precondition grad
                    if dim_id == order - 1:
                        # Last dimension
                        grad = grad_t @ inv_precond
                        grad = grad.view(original_size)
                    else:
                        grad = inv_precond @ grad
                        grad = grad.view(transposed_size)

                # Update parameter
                p.data.add_(grad, alpha=-lr)

        return loss

    def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
        # Add eps for numerical stability
        precond = precond + torch.eye(precond.size(0), device=precond.device) * eps

        # Compute matrix power
        inv_precond = self._matrix_power(precond, -1.0 / (2 * order))

        return inv_precond

    def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
        # Compute matrix power using SVD
        u, s, v = torch.svd(matrix)
        s_pow = s.pow(power)
        return u @ torch.diag(s_pow) @ v.t()

    # bring your own function to create zero-filled subclass
    @staticmethod
    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
        raise NotImplementedError

    # follow bitsandbytes, only quantize tensors >= 4096 values
    # also wrap subclass in DTensor when needed
    def _new_buffer(self, p: Tensor, signed: bool):
        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
            if isinstance(p, DTensor):
                out = DTensor.from_local(
                    local_tensor=self._subclass_zeros(
                        p.to_local(), signed, self.block_size
                    ),
                    device_mesh=p.device_mesh,
                    placements=p.placements,
                    run_check=False,
                )
            else:
                out = self._subclass_zeros(p, signed, self.block_size)
        else:
            out = torch.zeros_like(p)
        return out


class Shampoo8bit(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=256,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=8,
            optimizer_state_class=OptimState8bit,
        )


class Shampoo4bit(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=128,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=4,
            optimizer_state_class=OptimState4bit,
        )


class ShampooFp8(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=256,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=8,  # FP8 uses 8 bits
            optimizer_state_class=OptimStateFp8,
        )
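The new module adds Shampoo variants whose preconditioner state lives in torchao's prototype low-bit optimizer-state subclasses (8-bit, 4-bit, FP8). A minimal usage sketch on a toy model, assuming those torchao prototypes are importable; note that `_subclass_zeros` is left unimplemented in the base class and not overridden by the subclasses, so this sketch keeps the parameter tensor below the 4096-element threshold at which `_new_buffer` would try to quantize the momentum buffer:

```
import torch
from torch import nn

from axolotl.utils.optimizers.shampoo import Shampoo8bit  # added by this diff

# 48x48 = 2304 elements: small enough that the momentum buffer stays dense,
# while the (48, 48) preconditioners are still stored via OptimState8bit
model = nn.Linear(48, 48, bias=False)

opt = Shampoo8bit(
    model.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=0.0,
    update_freq=1,  # recompute the inverse preconditioner every step
    block_size=256,
)

x = torch.randn(16, 48)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```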
@@ -1,7 +1,6 @@
 """
 Simple end-to-end test for Liger integration
 """
-
 import unittest
 from pathlib import Path

tests/integrations/__init__.py: new file (0 lines)
tests/integrations/liger.py: new file (80 lines)
@@ -0,0 +1,80 @@ (new file; contents added in full)
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
    return DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
        }
    )


class BaseValidation:
    """
    Base validation module to setup the log capture
    """

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog


# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
    """
    Test the validation module for liger
    """

    def test_deprecated_swiglu(self, minimal_cfg):
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
            }
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            updated_cfg = validate_config(test_cfg)
            assert (
                "The 'liger_swiglu' argument is deprecated"
                in self._caplog.records[0].message
            )
            assert updated_cfg.liger_swiglu is None
            assert updated_cfg.liger_glu_activations is False

    def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
                "liger_glu_activations": True,
            }
            | minimal_cfg
        )

        with pytest.raises(
            ValueError,
            match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
        ):
            validate_config(test_cfg)
@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
         """Verify that processing data from the hub works with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
+
+            # make sure prepared_path is empty
+            shutil.rmtree(prepared_path, ignore_errors=True)
+
             cfg = DictDefault(
                 {
                     "tokenizer_config": "huggyllama/llama-7b",