diff --git a/.gitignore b/.gitignore
index 7b604d88c..4d7ba15a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -186,3 +186,6 @@ out/
 
 # vim
 *.swp
+
+# symlinked to axolotl-artifacts in docker containers
+outputs
diff --git a/model-out/eval_summary.csv b/model-out/eval_summary.csv
deleted file mode 100644
index ccbe73358..000000000
--- a/model-out/eval_summary.csv
+++ /dev/null
@@ -1,6 +0,0 @@
-metric,training,validation
-loss,1.8773103952407837,1.915901780128479
-model_preparation_time,0.0051,0.0051
-runtime,89.7635,8.9565
-samples_per_second,20.053,22.33
-steps_per_second,20.053,22.33
diff --git a/outputs b/outputs
deleted file mode 120000
index be3c4a823..000000000
--- a/outputs
+++ /dev/null
@@ -1 +0,0 @@
-/workspace/data/axolotl-artifacts
\ No newline at end of file
diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py
index 8e99d6f4b..655f3782f 100644
--- a/src/axolotl/cli/evaluate.py
+++ b/src/axolotl/cli/evaluate.py
@@ -3,7 +3,7 @@ CLI to run training on a model
 """
 import logging
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union
 
 import fire
 from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
 LOG = logging.getLogger("axolotl.cli.evaluate")
 
 
-def do_evaluate(cfg, cli_args) -> None:
+def do_evaluate(cfg, cli_args) -> Dict[str, float]:
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
-    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+    return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
 
 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_differential_transformer.py
index 8903da6d1..a687a3f7c 100644
--- a/src/axolotl/cli/integrations/convert_differential_transformer.py
+++ b/src/axolotl/cli/integrations/convert_differential_transformer.py
@@ -15,7 +15,7 @@ from transformers import HfArgumentParser
 from axolotl.cli import load_cfg, print_axolotl_text_art
 from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.differential_transformer.convert import (
-    convert_to_diff_attention,
+    convert_to_differential_attention,
 )
 
 LOG = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ def convert_differential_transformer(cfg, cli_args, config_path):
     # Convert attention
     LOG.info("Converting to differential attention...")
     try:
-        model = convert_to_diff_attention(
+        model = convert_to_differential_attention(
             model=model,
             zero_init=cli_args.zero_init,
             sublayer_norm=cli_args.sublayer_norm,
@@ -111,7 +111,10 @@ def convert_differential_transformer(cfg, cli_args, config_path):
             data = yaml.safe_load(file) or {}
 
         data["base_model"] = cfg.output_dir
-        data["diff_attention"] = True
+        data["differential_attention"] = True
+        data["plugins"] = [
+            "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
+        ]
 
         with open(output_config_path, "w", encoding="utf-8") as file:
             yaml.dump(data, file)
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
index b4ffd6758..f7d35fcf8 100644
--- a/src/axolotl/integrations/config.py
+++ b/src/axolotl/integrations/config.py
@@ -43,10 +43,12 @@ def merge_input_args():
     input_args: List[str] = plugin_manager.get_input_args()
     plugin_classes = []
     dynamic_input = ""
+
     for plugin_args in input_args:
         plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
         dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
         plugin_classes.append(plugin_cls)
+
     if dynamic_input:
         dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n    pass\n"
         dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n    pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
             "AxolotlConfigWCapabilities"
         ]
         return AxolotlConfigWCapabilities, AxolotlInputConfig
+
     return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
diff --git a/src/axolotl/integrations/differential_transformer/README.md b/src/axolotl/integrations/differential_transformer/README.md
new file mode 100644
index 000000000..f7bd74cbd
--- /dev/null
+++ b/src/axolotl/integrations/differential_transformer/README.md
@@ -0,0 +1,10 @@
+# Differential Transformer
+
+### Usage
+
+```yaml
+plugins:
+  - axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
+
+differential_attention: true
+```
diff --git a/src/axolotl/integrations/differential_transformer/__init__.py b/src/axolotl/integrations/differential_transformer/__init__.py
index e69de29bb..63741793c 100644
--- a/src/axolotl/integrations/differential_transformer/__init__.py
+++ b/src/axolotl/integrations/differential_transformer/__init__.py
@@ -0,0 +1,25 @@
+"""Definition of differential transformer plugin."""
+
+import logging
+
+from axolotl.integrations.base import BasePlugin
+
+LOG = logging.getLogger(__name__)
+
+
+class DifferentialTransformerPlugin(BasePlugin):
+    """
+    Plugin for differential transformer integration with Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
+
+    def pre_model_load(self, cfg):
+        """Apply differential attention patch before model loading if enabled."""
+        if cfg.differential_attention:
+            from axolotl.monkeypatch.attention.differential import (
+                patch_llama_attention_classes,
+            )
+
+            patch_llama_attention_classes()
diff --git a/src/axolotl/integrations/differential_transformer/args.py b/src/axolotl/integrations/differential_transformer/args.py
new file mode 100644
index 000000000..bd6e01520
--- /dev/null
+++ b/src/axolotl/integrations/differential_transformer/args.py
@@ -0,0 +1,14 @@
+"""Module for handling differential transformer input arguments."""
+
+import logging
+from typing import Optional
+
+from pydantic import BaseModel
+
+LOG = logging.getLogger(__name__)
+
+
+class DifferentialTransformerArgs(BaseModel):
+    """Input args for differential transformer."""
+
+    differential_attention: Optional[bool] = None
diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py
index ce3773037..d516f9476 100644
--- a/src/axolotl/integrations/differential_transformer/convert.py
+++ b/src/axolotl/integrations/differential_transformer/convert.py
@@ -80,7 +80,7 @@ def copy_attention_weights(
     )
 
 
-def convert_to_diff_attention(
+def convert_to_differential_attention(
     model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
 ) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index cab61c148..5ddf04811 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -724,8 +724,6 @@ class AxolotlInputConfig(
 
     eager_attention: Optional[bool] = None
 
-    diff_attention: Optional[bool] = None
-
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e98e9f31b..8c8bd0e38 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -444,13 +444,6 @@ class ModelLoader:
 
             patch_mistral_cross_entropy()
 
-        if self.cfg.diff_attention:
-            from axolotl.monkeypatch.attention.differential import (
-                patch_llama_attention_classes,
-            )
-
-            patch_llama_attention_classes()
-
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
            if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -721,7 +714,7 @@
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
 
-            if self.cfg.diff_attention:
+            if self.cfg.differential_attention:
                 self.model_kwargs[
                     "attn_implementation"
                 ] = "differential_flash_attention_2"
@@ -734,7 +727,7 @@
                     "flash_attention_2"
                 )
         elif self.cfg.sdp_attention:
-            if self.cfg.diff_attention:
+            if self.cfg.differential_attention:
                 self.model_kwargs["attn_implementation"] = "differential_sdpa"
                 self.model_config._attn_implementation = (  # pylint: disable=protected-access
                     "differential_sdpa"
@@ -745,7 +738,7 @@
                     "sdpa"
                 )
         elif self.cfg.eager_attention:
-            if self.cfg.diff_attention:
+            if self.cfg.differential_attention:
                 self.model_kwargs["attn_implementation"] = "differential_eager"
                 self.model_config._attn_implementation = (  # pylint: disable=protected-access
                     "differential_eager"
@@ -755,7 +748,7 @@
                 self.model_config._attn_implementation = (  # pylint: disable=protected-access
                     "eager"
                 )
-        elif self.cfg.diff_attention:
+        elif self.cfg.differential_attention:
             self.model_kwargs["attn_implementation"] = "differential_eager"
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
                 "differential_eager"
diff --git a/tests/e2e/integrations/test_convert_differential_transformer.py b/tests/e2e/integrations/test_convert_differential_transformer.py
index da3aac11a..9ddcf5767 100644
--- a/tests/e2e/integrations/test_convert_differential_transformer.py
+++ b/tests/e2e/integrations/test_convert_differential_transformer.py
@@ -6,12 +6,14 @@ from typing import Optional
 
 import pytest
 import yaml
+from pytest import approx
 
 from axolotl.cli import load_cfg
+from axolotl.cli.evaluate import do_evaluate
 from axolotl.cli.integrations.convert_differential_transformer import (
     convert_differential_transformer,
 )
-from axolotl.common.cli import ConvertDiffTransformerCliArgs
+from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
 
 
 @pytest.fixture()
@@ -19,9 +21,12 @@ def base_config():
     """Basic config for testing."""
     return {
         "base_model": "HuggingFaceTB/SmolLM2-135M",
+        "plugins": [
+            "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
+        ],
         "datasets": [
             {
-                "path": "mhenrichsen/alpaca_2k_test",
+                "path": "axolotl-ai-co/alpaca_100_test",
                 "type": "alpaca",
             },
         ],
@@ -103,7 +108,9 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
     assert (output_dir / "axolotl_config.yml").exists()
 
 
-@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"])
+@pytest.mark.parametrize(
+    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+)
 def test_conversion_cli_repoduce_attentions(
     tmp_path: Path, base_config, attention: Optional[str]
 ):
@@ -125,3 +132,42 @@
     assert (output_dir / "model.safetensors").exists()
     assert (output_dir / "config.json").exists()
     assert (output_dir / "axolotl_config.yml").exists()
+
+
+def test_conversion_and_eval_cli(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(
+        debug=True, zero_init=True, sublayer_norm=False
+    )
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is True
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+    eval_cfg = load_cfg(str(output_dir))
+    eval_cli_args = EvaluateCliArgs()
+    all_metrics = do_evaluate(eval_cfg, eval_cli_args)
+
+    assert list(all_metrics.keys()) == [
+        "train_loss",
+        "train_model_preparation_time",
+        "train_runtime",
+        "train_samples_per_second",
+        "train_steps_per_second",
+        "eval_loss",
+        "eval_model_preparation_time",
+        "eval_runtime",
+        "eval_samples_per_second",
+        "eval_steps_per_second",
+    ]
+    assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
+    assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)