plugin implementation
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -186,3 +186,6 @@ out/
|
|||||||
|
|
||||||
# vim
|
# vim
|
||||||
*.swp
|
*.swp
|
||||||
|
|
||||||
|
# symlinked to axolotl-artifacts in docker containers
|
||||||
|
outputs
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
metric,training,validation
|
|
||||||
loss,1.8773103952407837,1.915901780128479
|
|
||||||
model_preparation_time,0.0051,0.0051
|
|
||||||
runtime,89.7635,8.9565
|
|
||||||
samples_per_second,20.053,22.33
|
|
||||||
steps_per_second,20.053,22.33
|
|
||||||
|
@@ -3,7 +3,7 @@ CLI to run training on a model
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
|
|||||||
LOG = logging.getLogger("axolotl.cli.evaluate")
|
LOG = logging.getLogger("axolotl.cli.evaluate")
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg, cli_args) -> None:
|
def do_evaluate(cfg, cli_args) -> Dict[str, float]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from transformers import HfArgumentParser
|
|||||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.integrations.differential_transformer.convert import (
|
from axolotl.integrations.differential_transformer.convert import (
|
||||||
convert_to_diff_attention,
|
convert_to_differential_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -79,7 +79,7 @@ def convert_differential_transformer(cfg, cli_args, config_path):
|
|||||||
# Convert attention
|
# Convert attention
|
||||||
LOG.info("Converting to differential attention...")
|
LOG.info("Converting to differential attention...")
|
||||||
try:
|
try:
|
||||||
model = convert_to_diff_attention(
|
model = convert_to_differential_attention(
|
||||||
model=model,
|
model=model,
|
||||||
zero_init=cli_args.zero_init,
|
zero_init=cli_args.zero_init,
|
||||||
sublayer_norm=cli_args.sublayer_norm,
|
sublayer_norm=cli_args.sublayer_norm,
|
||||||
@@ -111,7 +111,10 @@ def convert_differential_transformer(cfg, cli_args, config_path):
|
|||||||
data = yaml.safe_load(file) or {}
|
data = yaml.safe_load(file) or {}
|
||||||
|
|
||||||
data["base_model"] = cfg.output_dir
|
data["base_model"] = cfg.output_dir
|
||||||
data["diff_attention"] = True
|
data["differential_attention"] = True
|
||||||
|
data["plugins"] = [
|
||||||
|
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
|
||||||
|
]
|
||||||
|
|
||||||
with open(output_config_path, "w", encoding="utf-8") as file:
|
with open(output_config_path, "w", encoding="utf-8") as file:
|
||||||
yaml.dump(data, file)
|
yaml.dump(data, file)
|
||||||
|
|||||||
@@ -43,10 +43,12 @@ def merge_input_args():
|
|||||||
input_args: List[str] = plugin_manager.get_input_args()
|
input_args: List[str] = plugin_manager.get_input_args()
|
||||||
plugin_classes = []
|
plugin_classes = []
|
||||||
dynamic_input = ""
|
dynamic_input = ""
|
||||||
|
|
||||||
for plugin_args in input_args:
|
for plugin_args in input_args:
|
||||||
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
||||||
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
||||||
plugin_classes.append(plugin_cls)
|
plugin_classes.append(plugin_cls)
|
||||||
|
|
||||||
if dynamic_input:
|
if dynamic_input:
|
||||||
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
|
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||||
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||||
@@ -62,4 +64,5 @@ def merge_input_args():
|
|||||||
"AxolotlConfigWCapabilities"
|
"AxolotlConfigWCapabilities"
|
||||||
]
|
]
|
||||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||||
|
|
||||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||||
|
|||||||
10
src/axolotl/integrations/differential_transformer/README.md
Normal file
10
src/axolotl/integrations/differential_transformer/README.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Differential Transformer
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
|
||||||
|
|
||||||
|
differential_attention: true
|
||||||
|
```
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""Definition of differential transformer plugin."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialTransformerPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for differential transformer integration with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg):
|
||||||
|
"""Apply differential attention patch before model loading if enabled."""
|
||||||
|
if cfg.differential_attention:
|
||||||
|
from axolotl.monkeypatch.attention.differential import (
|
||||||
|
patch_llama_attention_classes,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_llama_attention_classes()
|
||||||
|
|||||||
14
src/axolotl/integrations/differential_transformer/args.py
Normal file
14
src/axolotl/integrations/differential_transformer/args.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Module for handling differential transfomer input arguments."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialTransformerArgs(BaseModel):
|
||||||
|
"""Input args for differential transformer."""
|
||||||
|
|
||||||
|
differential_attention: Optional[bool] = None
|
||||||
@@ -80,7 +80,7 @@ def copy_attention_weights(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_diff_attention(
|
def convert_to_differential_attention(
|
||||||
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
|
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
"""Convert a pre-trained model's attention layers to differential attention"""
|
"""Convert a pre-trained model's attention layers to differential attention"""
|
||||||
|
|||||||
@@ -727,8 +727,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
eager_attention: Optional[bool] = None
|
eager_attention: Optional[bool] = None
|
||||||
|
|
||||||
diff_attention: Optional[bool] = None
|
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: Optional[bool] = None
|
unsloth_cross_entropy_loss: Optional[bool] = None
|
||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
|
|||||||
@@ -444,13 +444,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_mistral_cross_entropy()
|
patch_mistral_cross_entropy()
|
||||||
|
|
||||||
if self.cfg.diff_attention:
|
|
||||||
from axolotl.monkeypatch.attention.differential import (
|
|
||||||
patch_llama_attention_classes,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_llama_attention_classes()
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -721,7 +714,7 @@ class ModelLoader:
|
|||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.cfg.diff_attention:
|
if self.cfg.differential_attention:
|
||||||
self.model_kwargs[
|
self.model_kwargs[
|
||||||
"attn_implementation"
|
"attn_implementation"
|
||||||
] = "differential_flash_attention_2"
|
] = "differential_flash_attention_2"
|
||||||
@@ -734,7 +727,7 @@ class ModelLoader:
|
|||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
elif self.cfg.sdp_attention:
|
elif self.cfg.sdp_attention:
|
||||||
if self.cfg.diff_attention:
|
if self.cfg.differential_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_sdpa"
|
"differential_sdpa"
|
||||||
@@ -745,7 +738,7 @@ class ModelLoader:
|
|||||||
"sdpa"
|
"sdpa"
|
||||||
)
|
)
|
||||||
elif self.cfg.eager_attention:
|
elif self.cfg.eager_attention:
|
||||||
if self.cfg.diff_attention:
|
if self.cfg.differential_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_eager"
|
"differential_eager"
|
||||||
@@ -755,7 +748,7 @@ class ModelLoader:
|
|||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"eager"
|
"eager"
|
||||||
)
|
)
|
||||||
elif self.cfg.diff_attention:
|
elif self.cfg.differential_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_eager"
|
"differential_eager"
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
from pytest import approx
|
||||||
|
|
||||||
from axolotl.cli import load_cfg
|
from axolotl.cli import load_cfg
|
||||||
|
from axolotl.cli.evaluate import do_evaluate
|
||||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||||
convert_differential_transformer,
|
convert_differential_transformer,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@@ -19,9 +21,12 @@ def base_config():
|
|||||||
"""Basic config for testing."""
|
"""Basic config for testing."""
|
||||||
return {
|
return {
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
|
||||||
|
],
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "axolotl-ai-co/alpaca_100_test",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -103,7 +108,9 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
|
|||||||
assert (output_dir / "axolotl_config.yml").exists()
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"])
|
@pytest.mark.parametrize(
|
||||||
|
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
||||||
|
)
|
||||||
def test_conversion_cli_repoduce_attentions(
|
def test_conversion_cli_repoduce_attentions(
|
||||||
tmp_path: Path, base_config, attention: Optional[str]
|
tmp_path: Path, base_config, attention: Optional[str]
|
||||||
):
|
):
|
||||||
@@ -125,3 +132,42 @@ def test_conversion_cli_repoduce_attentions(
|
|||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
assert (output_dir / "config.json").exists()
|
assert (output_dir / "config.json").exists()
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
|
||||||
|
output_dir = tmp_path / "converted"
|
||||||
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
cfg = load_cfg(str(config_path))
|
||||||
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
|
)
|
||||||
|
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
|
assert debug_info["generations_match"] is True
|
||||||
|
assert (output_dir / "model.safetensors").exists()
|
||||||
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
|
eval_cfg = load_cfg(str(output_dir))
|
||||||
|
eval_cli_args = EvaluateCliArgs()
|
||||||
|
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
|
||||||
|
|
||||||
|
assert list(all_metrics.keys()) == [
|
||||||
|
"train_loss",
|
||||||
|
"train_model_preparation_time",
|
||||||
|
"train_runtime",
|
||||||
|
"train_samples_per_second",
|
||||||
|
"train_steps_per_second",
|
||||||
|
"eval_loss",
|
||||||
|
"eval_model_preparation_time",
|
||||||
|
"eval_runtime",
|
||||||
|
"eval_samples_per_second",
|
||||||
|
"eval_steps_per_second",
|
||||||
|
]
|
||||||
|
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
|
||||||
|
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)
|
||||||
|
|||||||
Reference in New Issue
Block a user