support for latest transformers release 4.48.1 (#2256)
This commit is contained in:
@@ -1079,6 +1079,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
"""
|
||||
fix for FSDP gradient accumulation
|
||||
see https://github.com/huggingface/transformers/pull/35128
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from transformers import LlamaForCausalLM, Trainer
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||
|
||||
ORIGINAL_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
"""
|
||||
|
||||
PATCHED_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
"""
|
||||
|
||||
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
"""
|
||||
|
||||
PATCHED_LLAMA_FCLM_CODE = """
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||
"""
|
||||
|
||||
|
||||
def get_training_step_code() -> str:
|
||||
training_step = inspect.getsource(
|
||||
Trainer.training_step # pylint: disable=protected-access
|
||||
)
|
||||
return training_step
|
||||
|
||||
|
||||
def check_training_step_is_patchable() -> bool:
|
||||
training_step = get_training_step_code()
|
||||
training_step, _ = detab_code(training_step)
|
||||
return ORIGINAL_CONTEXT_CODE in training_step
|
||||
|
||||
|
||||
def patch_training_step_for_ga():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for gradient accumulation
|
||||
"""
|
||||
|
||||
try:
|
||||
training_step = get_training_step_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_training_step = training_step # pylint: disable=protected-access
|
||||
training_step, _ = detab_code(training_step)
|
||||
if ORIGINAL_CONTEXT_CODE not in training_step:
|
||||
return
|
||||
# assert (
|
||||
# ORIGINAL_CONTEXT_CODE in training_step
|
||||
# ), "Original training_step code not found"
|
||||
|
||||
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||
training_step = training_step.replace(
|
||||
"def training_step(",
|
||||
"def _fixed_training_step(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in training_step:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching training_step")
|
||||
Trainer.training_step = ( # pylint: disable=protected-access
|
||||
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def get_model_forward_code() -> str:
|
||||
forward = inspect.getsource(
|
||||
LlamaForCausalLM.forward # pylint: disable=protected-access
|
||||
)
|
||||
return forward
|
||||
|
||||
|
||||
def check_forward_is_patchable() -> bool:
|
||||
forward = get_model_forward_code()
|
||||
forward, _ = detab_code(forward)
|
||||
return ORIGINAL_LLAMA_FCLM_CODE in forward
|
||||
|
||||
|
||||
def patch_forward_for_ga():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for gradient accumulation
|
||||
"""
|
||||
|
||||
try:
|
||||
forward = get_model_forward_code()
|
||||
except OSError:
|
||||
return
|
||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||
forward, _ = detab_code(forward)
|
||||
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
|
||||
return
|
||||
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
|
||||
|
||||
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
|
||||
forward = forward.replace(
|
||||
"def forward(",
|
||||
"def _fixed_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.models.llama.modeling_llama):
|
||||
if item in forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.models.llama.modeling_llama import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching forward")
|
||||
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
|
||||
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
with context():
|
||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
disable_deepspeed_no_sync = (
|
||||
self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
|
||||
)
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
with context():
|
||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||
"""
|
||||
|
||||
|
||||
def get_training_loop_code() -> str:
|
||||
training_loop = inspect.getsource(
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_training_loop_is_patchable() -> bool:
|
||||
training_loop = get_training_loop_code()
|
||||
training_loop, _ = detab_code(training_loop)
|
||||
return ORIGINAL_TRAINER_CODE in training_loop
|
||||
|
||||
|
||||
def patch_training_loop_for_deepspeed_0_16_x():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for deepspeed GA
|
||||
|
||||
see https://github.com/huggingface/transformers/pull/35157
|
||||
"""
|
||||
|
||||
try:
|
||||
training_loop = get_training_loop_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||
training_loop
|
||||
)
|
||||
training_loop, _ = detab_code(training_loop)
|
||||
if ORIGINAL_TRAINER_CODE not in training_loop:
|
||||
return
|
||||
|
||||
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
|
||||
training_loop = training_loop.replace(
|
||||
"def _inner_training_loop(",
|
||||
"def _fixed_inner_training_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in training_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def patch_flash_attention_forward():
|
||||
"""
|
||||
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
|
||||
"""
|
||||
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
def proxy_flash_attention_forward(*args, **kwargs):
|
||||
kwargs.pop("num_items_in_batch", None)
|
||||
|
||||
return _flash_attention_forward(*args, **kwargs)
|
||||
|
||||
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
|
||||
proxy_flash_attention_forward
|
||||
)
|
||||
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
|
||||
proxy_flash_attention_forward
|
||||
)
|
||||
67
src/axolotl/monkeypatch/transformers_fa_utils.py
Normal file
67
src/axolotl/monkeypatch/transformers_fa_utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
see https://github.com/huggingface/transformers/pull/35834
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fixed_fa_peft_integration_check(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
target_dtype: Optional[torch.dtype] = None,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""
|
||||
PEFT usually casts the layer norms in float32 for training stability reasons
|
||||
therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
cast them back in float16 / bfloat16 just to be sure everything works as expected.
|
||||
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
target_dtype (`torch.dtype`, *optional*):
|
||||
The dtype to convert the attention tensors to. Conversion can be ignored by
|
||||
not providing the target dtype.
|
||||
preferred_dtype (`torch.dtype`, *optional*):
|
||||
The preferred dtype to convert the attention tensors to regardless of the
|
||||
target dtype.
|
||||
"""
|
||||
if target_dtype is None and preferred_dtype is None:
|
||||
return query, key, value
|
||||
|
||||
if preferred_dtype and target_dtype != preferred_dtype:
|
||||
target_dtype = preferred_dtype
|
||||
|
||||
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
|
||||
if any(module.dtype == torch.float32 for module in [query, key, value]):
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query = query.to(target_dtype)
|
||||
key = key.to(target_dtype)
|
||||
value = value.to(target_dtype)
|
||||
|
||||
return query, key, value
|
||||
|
||||
|
||||
def patch_fa_peft_integration():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
|
||||
fixed_fa_peft_integration_check, preferred_dtype=None
|
||||
)
|
||||
@@ -380,23 +380,19 @@ class ModelLoader:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
|
||||
if self.cfg.adapter:
|
||||
from axolotl.monkeypatch.transformers_fa_utils import (
|
||||
patch_fa_peft_integration,
|
||||
)
|
||||
|
||||
patch_fa_peft_integration()
|
||||
|
||||
if self.cfg.gradient_checkpointing == "unsloth":
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
self.patch_attention()
|
||||
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
patch_flash_attention_forward,
|
||||
patch_forward_for_ga,
|
||||
patch_training_step_for_ga,
|
||||
)
|
||||
|
||||
patch_flash_attention_forward()
|
||||
patch_forward_for_ga()
|
||||
patch_training_step_for_ga()
|
||||
|
||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
raise ValueError(
|
||||
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
||||
|
||||
Reference in New Issue
Block a user