training fixes, patching, minor cleanup

Dan Saunders
2024-12-13 00:06:22 -05:00
parent e162d36fe9
commit af1d8d69af
8 changed files with 136 additions and 29 deletions

View File

@@ -7,6 +7,7 @@ from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
@@ -50,13 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
raise
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
def convert_diff_transformer(cfg, cli_args, config_path):
try:
# Load model and tokenizer
with warnings.catch_warnings():
@@ -90,8 +85,26 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
# Modify config to reflect new path / differential attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)
with open(config_path, "r", encoding="utf-8") as file:
data = yaml.safe_load(file) or {}
data["base_model"] = cfg.output_dir
data["diff_attention"] = True
with open(output_config_path, "w", encoding="utf-8") as file:
yaml.dump(data, file)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
LOG.info(
Fore.GREEN
@@ -122,10 +135,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
raise
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_diff_transformer(cfg, cli_args, config)
if __name__ == "__main__":
load_dotenv()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
fire.Fire(do_cli)
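As a quick sanity check, the config written by the conversion step above can be reloaded to confirm the rewritten fields. A minimal sketch, not part of the diff; the output directory name is a placeholder for cfg.output_dir.

import yaml

output_dir = "outputs/diff-transformer"  # placeholder for cfg.output_dir
with open(f"{output_dir}/axolotl_config.yml", "r", encoding="utf-8") as file:
    data = yaml.safe_load(file) or {}

# The conversion points base_model at the converted weights and turns on
# differential attention for any follow-up run that loads this config.
assert data["base_model"] == output_dir
assert data["diff_attention"] is True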

View File

@@ -242,11 +242,6 @@ def merge_lora(
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save converted model",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):

View File

@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""

View File

@@ -88,7 +88,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
for name, child in module.named_children():
if isinstance(child, attention_patterns):
layer_type = type(child).__name__
logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
# Choose appropriate differential attention class
if isinstance(child, LlamaSdpaAttention):
@@ -96,6 +95,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
else:
attention_class = LlamaDifferentialAttention
logger.info(
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
)
# Create new diff attn layer
new_attention = attention_class(
config=module.config if hasattr(module, "config") else model.config,

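The hunk above picks LlamaDifferentialSdpaAttention for SDPA layers and LlamaDifferentialAttention otherwise, then rebuilds each layer from the module (or model) config. A minimal sketch of that replace-in-place pattern; the constructor arguments and weight handling are simplified here and are assumptions, not the diff's implementation.

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention

def swap_attention_layers(model, eager_cls, sdpa_cls):
    # Walk every submodule and replace matching attention children in place.
    for module in model.modules():
        for name, child in module.named_children():
            if isinstance(child, LlamaSdpaAttention):
                attention_class = sdpa_cls  # SDPA subclasses eager attention, so check it first
            elif isinstance(child, LlamaAttention):
                attention_class = eager_cls
            else:
                continue
            config = module.config if hasattr(module, "config") else model.config
            setattr(module, name, attention_class(config=config))
    return model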
View File

@@ -0,0 +1,46 @@
"""Patches related to differential transformers implementation."""
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from .multihead_diffattn import (
LlamaDifferentialAttention,
LlamaDifferentialSdpaAttention,
)
def patch_transformers():
"""Patch transformers to support differential attention"""
# Add our attention class to the registry
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
# Store original method for use in our patch
# original_autoset = PreTrainedModel._autoset_attn_implementation
@classmethod
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
config._attn_implementation_autoset = True # pylint: disable=protected-access
attn_implementation = getattr(config, "_attn_implementation", None)
valid_impls = [
None,
"eager",
"sdpa",
"flash_attention_2",
"differential_eager",
"differential_sdpa",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message + ".")
return config
# Apply patch
PreTrainedModel._autoset_attn_implementation = ( # pylint: disable=protected-access
new_autoset
)
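Illustrative usage of the patch, not part of the diff: it has to run before model construction so the differential classes are registered and the implementation string passes validation. The checkpoint name is a placeholder.

from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.patches import patch_transformers

patch_transformers()  # registers "differential_eager" / "differential_sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    attn_implementation="differential_sdpa",
)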

View File

@@ -87,7 +87,7 @@ def train(
)
resume_from_checkpoint = cfg.resume_from_checkpoint
# Load the model and tokenizer
# Load the model
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."

View File

@@ -724,6 +724,8 @@ class AxolotlInputConfig(
eager_attention: Optional[bool] = None
diff_attention: Optional[bool] = None
unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None

View File

@@ -710,24 +710,60 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
print(
self.cfg.flash_attention,
self.cfg.sdp_attention,
self.cfg.eager_attention,
self.cfg.diff_attention,
)
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
"differential_eager"
)
if "attn_implementation" in self.model_kwargs:
print(self.model_kwargs["attn_implementation"])
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
@@ -816,6 +852,8 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
# self.model._attn_implementation_autoset = False
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1030,6 +1068,10 @@ class ModelLoader:
integrate_rope_embeddings()
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
from axolotl.integrations.diff_transformer.patches import patch_transformers
patch_transformers()
self.apply_patches()
self.set_auto_model_loader()
self.set_device_map_config()