training fixes, patching, minor cleanup
@@ -7,6 +7,7 @@ from typing import Union

import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
@@ -50,13 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
        raise


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    print_axolotl_text_art()

    cfg = load_cfg(config, **kwargs)
    parser = HfArgumentParser(TrainerCliArgs)
    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

def convert_diff_transformer(cfg, cli_args, config_path):
    try:
        # Load model and tokenizer
        with warnings.catch_warnings():
@@ -90,8 +85,26 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

        # Save if requested
        if cfg.output_dir:
            # Save model and tokenizer
            LOG.info("Saving converted model to %s", cfg.output_dir)
            model.save_pretrained(cfg.output_dir)
            tokenizer.save_pretrained(cfg.output_dir)

            # Modify config to reflect new path / differential attention
            output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
            LOG.info("Saving updated config to %s", output_config_path)

            with open(config_path, "r", encoding="utf-8") as file:
                data = yaml.safe_load(file) or {}

            data["base_model"] = cfg.output_dir
            data["diff_attention"] = True

            with open(output_config_path, "w", encoding="utf-8") as file:
                yaml.dump(data, file)
        else:
            LOG.info("Not saving converted model to disk")
            LOG.info("Pass --output-dir path/to/save to save model")

        LOG.info(
            Fore.GREEN
@@ -122,10 +135,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
        raise


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    print_axolotl_text_art()

    cfg = load_cfg(config, **kwargs)
    parser = HfArgumentParser(TrainerCliArgs)
    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    convert_diff_transformer(cfg, cli_args, config)


if __name__ == "__main__":
    load_dotenv()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    fire.Fire(do_cli)
@@ -242,11 +242,6 @@ def merge_lora(

@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--output-dir",
    type=click.Path(path_type=str),
    help="Directory to save converted model",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """
    Training arguments for Causal trainer

    This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
    so it can't be used as a mixin.
    """
@@ -88,7 +88,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
    for name, child in module.named_children():
        if isinstance(child, attention_patterns):
            layer_type = type(child).__name__
            logger.info(f"Converting attention layer {layer_idx}: {layer_type}")

            # Choose appropriate differential attention class
            if isinstance(child, LlamaSdpaAttention):
@@ -96,6 +95,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
            else:
                attention_class = LlamaDifferentialAttention

            logger.info(
                f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
            )

            # Create new diff attn layer
            new_attention = attention_class(
                config=module.config if hasattr(module, "config") else model.config,
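For orientation, a minimal sketch of how this conversion routine is driven; this is not part of the commit — the module path and checkpoint paths are assumptions, only the function name and signature come from the hunk header above:

    from transformers import AutoModelForCausalLM

    # module path assumed for illustration
    from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

    model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path
    model = convert_to_diff_attention(model)  # replaces Llama attention layers with differential variants
    model.save_pretrained("path/to/converted-model")  # placeholder output path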

src/axolotl/integrations/diff_transformer/patches.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""Patches related to differential transformers implementation."""
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

from .multihead_diffattn import (
    LlamaDifferentialAttention,
    LlamaDifferentialSdpaAttention,
)


def patch_transformers():
    """Patch transformers to support differential attention"""

    # Add our attention class to the registry
    LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
    LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention

    # Store original method for use in our patch
    # original_autoset = PreTrainedModel._autoset_attn_implementation

    @classmethod
    def new_autoset(_, config, **kwargs):  # pylint: disable=unused-argument
        config._attn_implementation_autoset = True  # pylint: disable=protected-access
        attn_implementation = getattr(config, "_attn_implementation", None)

        valid_impls = [
            None,
            "eager",
            "sdpa",
            "flash_attention_2",
            "differential_eager",
            "differential_sdpa",
        ]
        if attn_implementation not in valid_impls:
            message = (
                f"Specified `attn_implementation={attn_implementation}` is not supported. "
                f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
            )
            raise ValueError(message + ".")

        return config

    # Apply patch
    PreTrainedModel._autoset_attn_implementation = (  # pylint: disable=protected-access
        new_autoset
    )
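A minimal sketch of how the new patch is intended to be used; the checkpoint path is a placeholder, and setting both the keyword argument and the config attribute simply mirrors what the ModelLoader changes further below do:

    from transformers import AutoConfig, AutoModelForCausalLM

    from axolotl.integrations.diff_transformer.patches import patch_transformers

    patch_transformers()  # register differential classes and relax the implementation check

    config = AutoConfig.from_pretrained("path/to/llama-checkpoint")  # placeholder path
    config._attn_implementation = "differential_sdpa"  # pylint: disable=protected-access

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/llama-checkpoint",
        config=config,
        attn_implementation="differential_sdpa",
    )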
@@ -88,7 +88,7 @@ def train(
    )
    resume_from_checkpoint = cfg.resume_from_checkpoint

    # Load the model and tokenizer
    # Load the model
    msg = "loading model"
    if cfg.adapter:
        msg += " and peft_config..."
@@ -727,6 +727,8 @@ class AxolotlInputConfig(

    eager_attention: Optional[bool] = None

    diff_attention: Optional[bool] = None

    unsloth_cross_entropy_loss: Optional[bool] = None
    unsloth_lora_mlp: Optional[bool] = None
    unsloth_lora_qkv: Optional[bool] = None
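Illustrative only: how the new flag combines with the existing attention flags once a YAML config is loaded (field names from the hunk above; values made up):

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "path/to/llama-checkpoint",  # placeholder
            "sdp_attention": True,
            "diff_attention": True,  # new field in this commit; selects "differential_sdpa" below
        }
    )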
@@ -710,24 +710,60 @@ class ModelLoader:
        """
        sample packing uses custom FA2 patch
        """
        print(
            self.cfg.flash_attention,
            self.cfg.sdp_attention,
            self.cfg.eager_attention,
            self.cfg.diff_attention,
        )

        if self.cfg.flash_attention:
            if not self.cfg.sample_packing and self.cfg.s2_attention:
                pass
            self.model_kwargs["attn_implementation"] = "flash_attention_2"
            self.model_config._attn_implementation = (  # pylint: disable=protected-access
                "flash_attention_2"
            )

            if self.cfg.diff_attention:
                self.model_kwargs[
                    "attn_implementation"
                ] = "differential_flash_attention_2"
                self.model_config._attn_implementation = (  # pylint: disable=protected-access
                    "differential_flash_attention_2"
                )
            else:
                self.model_kwargs["attn_implementation"] = "flash_attention_2"
                self.model_config._attn_implementation = (  # pylint: disable=protected-access
                    "flash_attention_2"
                )
        elif self.cfg.sdp_attention:
            self.model_kwargs["attn_implementation"] = "sdpa"
            self.model_config._attn_implementation = (  # pylint: disable=protected-access
                "sdpa"
            )
            if self.cfg.diff_attention:
                self.model_kwargs["attn_implementation"] = "differential_sdpa"
                self.model_config._attn_implementation = (  # pylint: disable=protected-access
                    "differential_sdpa"
                )
            else:
                self.model_kwargs["attn_implementation"] = "sdpa"
                self.model_config._attn_implementation = (  # pylint: disable=protected-access
                    "sdpa"
                )
        elif self.cfg.eager_attention:
            self.model_kwargs["attn_implementation"] = "eager"
            if self.cfg.diff_attention:
                self.model_kwargs["attn_implementation"] = "differential_eager"
                self.model_config._attn_implementation = (  # pylint: disable=protected-access
                    "differential_eager"
                )
            else:
                self.model_kwargs["attn_implementation"] = "eager"
                self.model_config._attn_implementation = (  # pylint: disable=protected-access
                    "eager"
                )
        elif self.cfg.diff_attention:
            self.model_kwargs["attn_implementation"] = "differential_eager"
            self.model_config._attn_implementation = (  # pylint: disable=protected-access
                "eager"
                "differential_eager"
            )

        if "attn_implementation" in self.model_kwargs:
            print(self.model_kwargs["attn_implementation"])

        if self.cfg.low_cpu_mem_usage:
            self.model_kwargs["low_cpu_mem_usage"] = True
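A condensed sketch (not part of the commit) of the selection logic above, to make the flag interplay easier to follow; `cfg` is any object exposing the boolean fields used above, and the return value corresponds to what ends up in model_kwargs["attn_implementation"]:

    def pick_attn_implementation(cfg):
        if cfg.flash_attention:
            return "differential_flash_attention_2" if cfg.diff_attention else "flash_attention_2"
        if cfg.sdp_attention:
            return "differential_sdpa" if cfg.diff_attention else "sdpa"
        if cfg.eager_attention:
            return "differential_eager" if cfg.diff_attention else "eager"
        if cfg.diff_attention:
            return "differential_eager"
        return None  # no explicit implementation requested; transformers falls back to its default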
@@ -816,6 +852,8 @@ class ModelLoader:

        if self.cfg.is_multimodal:
            self.model_config.text_config = self.text_model_config

        # self.model._attn_implementation_autoset = False
        self.model = self.AutoModelLoader.from_pretrained(
            self.base_model,
            config=self.model_config,
@@ -1030,6 +1068,10 @@ class ModelLoader:
        integrate_rope_embeddings()

    def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
        from axolotl.integrations.diff_transformer.patches import patch_transformers

        patch_transformers()

        self.apply_patches()
        self.set_auto_model_loader()
        self.set_device_map_config()