training fixes, patching, minor cleanup
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import Union
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
|
import yaml
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers import HfArgumentParser
|
from transformers import HfArgumentParser
|
||||||
@@ -50,13 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
cfg = load_cfg(config, **kwargs)
|
|
||||||
parser = HfArgumentParser(TrainerCliArgs)
|
|
||||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@@ -90,8 +85,26 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
# Save if requested
|
# Save if requested
|
||||||
if cfg.output_dir:
|
if cfg.output_dir:
|
||||||
|
# Save model and tokenizer
|
||||||
LOG.info("Saving converted model to %s", cfg.output_dir)
|
LOG.info("Saving converted model to %s", cfg.output_dir)
|
||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
tokenizer.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
# Modify config to reflect new path / differential attention
|
||||||
|
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
|
||||||
|
LOG.info("Saving updated config to %s", output_config_path)
|
||||||
|
|
||||||
|
with open(config_path, "r", encoding="utf-8") as file:
|
||||||
|
data = yaml.safe_load(file) or {}
|
||||||
|
|
||||||
|
data["base_model"] = cfg.output_dir
|
||||||
|
data["diff_attention"] = True
|
||||||
|
|
||||||
|
with open(output_config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(data, file)
|
||||||
|
else:
|
||||||
|
LOG.info("Not saving converted model to disk")
|
||||||
|
LOG.info("Pass --output-dir path/to/save to save model")
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
Fore.GREEN
|
Fore.GREEN
|
||||||
@@ -122,10 +135,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
cfg = load_cfg(config, **kwargs)
|
||||||
|
parser = HfArgumentParser(TrainerCliArgs)
|
||||||
|
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||||
|
|
||||||
|
convert_diff_transformer(cfg, cli_args, config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
|
||||||
)
|
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -242,11 +242,6 @@ def merge_lora(
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@click.option(
|
|
||||||
"--output-dir",
|
|
||||||
type=click.Path(path_type=str),
|
|
||||||
help="Directory to save converted model",
|
|
||||||
)
|
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
def convert_diff_transformer(config: str, **kwargs):
|
def convert_diff_transformer(config: str, **kwargs):
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
|
||||||
so it can't be used as a mixin.
|
so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
|||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
if isinstance(child, attention_patterns):
|
if isinstance(child, attention_patterns):
|
||||||
layer_type = type(child).__name__
|
layer_type = type(child).__name__
|
||||||
logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
|
|
||||||
|
|
||||||
# Choose appropriate differential attention class
|
# Choose appropriate differential attention class
|
||||||
if isinstance(child, LlamaSdpaAttention):
|
if isinstance(child, LlamaSdpaAttention):
|
||||||
@@ -96,6 +95,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
|||||||
else:
|
else:
|
||||||
attention_class = LlamaDifferentialAttention
|
attention_class = LlamaDifferentialAttention
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create new diff attn layer
|
# Create new diff attn layer
|
||||||
new_attention = attention_class(
|
new_attention = attention_class(
|
||||||
config=module.config if hasattr(module, "config") else model.config,
|
config=module.config if hasattr(module, "config") else model.config,
|
||||||
|
|||||||
46
src/axolotl/integrations/diff_transformer/patches.py
Normal file
46
src/axolotl/integrations/diff_transformer/patches.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Patches related to differential transformers implementation."""
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
||||||
|
|
||||||
|
from .multihead_diffattn import (
|
||||||
|
LlamaDifferentialAttention,
|
||||||
|
LlamaDifferentialSdpaAttention,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_transformers():
|
||||||
|
"""Patch transformers to support differential attention"""
|
||||||
|
|
||||||
|
# Add our attention class to the registry
|
||||||
|
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
|
||||||
|
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
|
||||||
|
|
||||||
|
# Store original method for use in our patch
|
||||||
|
# original_autoset = PreTrainedModel._autoset_attn_implementation
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
|
||||||
|
config._attn_implementation_autoset = True # pylint: disable=protected-access
|
||||||
|
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||||
|
|
||||||
|
valid_impls = [
|
||||||
|
None,
|
||||||
|
"eager",
|
||||||
|
"sdpa",
|
||||||
|
"flash_attention_2",
|
||||||
|
"differential_eager",
|
||||||
|
"differential_sdpa",
|
||||||
|
]
|
||||||
|
if attn_implementation not in valid_impls:
|
||||||
|
message = (
|
||||||
|
f"Specified `attn_implementation={attn_implementation}` is not supported. "
|
||||||
|
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
|
||||||
|
)
|
||||||
|
raise ValueError(message + ".")
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
# Apply patch
|
||||||
|
PreTrainedModel._autoset_attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
new_autoset
|
||||||
|
)
|
||||||
@@ -87,7 +87,7 @@ def train(
|
|||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model
|
||||||
msg = "loading model"
|
msg = "loading model"
|
||||||
if cfg.adapter:
|
if cfg.adapter:
|
||||||
msg += " and peft_config..."
|
msg += " and peft_config..."
|
||||||
|
|||||||
@@ -724,6 +724,8 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
eager_attention: Optional[bool] = None
|
eager_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
diff_attention: Optional[bool] = None
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: Optional[bool] = None
|
unsloth_cross_entropy_loss: Optional[bool] = None
|
||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
|
|||||||
@@ -710,24 +710,60 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
|
print(
|
||||||
|
self.cfg.flash_attention,
|
||||||
|
self.cfg.sdp_attention,
|
||||||
|
self.cfg.eager_attention,
|
||||||
|
self.cfg.diff_attention,
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
if self.cfg.diff_attention:
|
||||||
"flash_attention_2"
|
self.model_kwargs[
|
||||||
)
|
"attn_implementation"
|
||||||
|
] = "differential_flash_attention_2"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"differential_flash_attention_2"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"flash_attention_2"
|
||||||
|
)
|
||||||
elif self.cfg.sdp_attention:
|
elif self.cfg.sdp_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
if self.cfg.diff_attention:
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
||||||
"sdpa"
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
)
|
"differential_sdpa"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"sdpa"
|
||||||
|
)
|
||||||
elif self.cfg.eager_attention:
|
elif self.cfg.eager_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "eager"
|
if self.cfg.diff_attention:
|
||||||
|
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"differential_eager"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model_kwargs["attn_implementation"] = "eager"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"eager"
|
||||||
|
)
|
||||||
|
elif self.cfg.diff_attention:
|
||||||
|
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"eager"
|
"differential_eager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "attn_implementation" in self.model_kwargs:
|
||||||
|
print(self.model_kwargs["attn_implementation"])
|
||||||
|
|
||||||
if self.cfg.low_cpu_mem_usage:
|
if self.cfg.low_cpu_mem_usage:
|
||||||
self.model_kwargs["low_cpu_mem_usage"] = True
|
self.model_kwargs["low_cpu_mem_usage"] = True
|
||||||
|
|
||||||
@@ -816,6 +852,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
|
|
||||||
|
# self.model._attn_implementation_autoset = False
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -1030,6 +1068,10 @@ class ModelLoader:
|
|||||||
integrate_rope_embeddings()
|
integrate_rope_embeddings()
|
||||||
|
|
||||||
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
|
from axolotl.integrations.diff_transformer.patches import patch_transformers
|
||||||
|
|
||||||
|
patch_transformers()
|
||||||
|
|
||||||
self.apply_patches()
|
self.apply_patches()
|
||||||
self.set_auto_model_loader()
|
self.set_auto_model_loader()
|
||||||
self.set_device_map_config()
|
self.set_device_map_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user