support for latest transformers release 4.48.1 (#2256)

This commit is contained in:
Wing Lian
2025-01-23 21:17:57 -05:00
committed by GitHub
parent 8fb72cbc0b
commit 8a7a0b07dc
13 changed files with 98 additions and 363 deletions

View File

@@ -1079,6 +1079,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:

View File

@@ -1,308 +0,0 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
def patch_flash_attention_forward():
"""
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
"""
import transformers.modeling_flash_attention_utils
def proxy_flash_attention_forward(*args, **kwargs):
kwargs.pop("num_items_in_batch", None)
return _flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)

View File

@@ -0,0 +1,67 @@
"""
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
def fixed_fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
preferred_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
preferred_dtype (`torch.dtype`, *optional*):
The preferred dtype to convert the attention tensors to regardless of the
target dtype.
"""
if target_dtype is None and preferred_dtype is None:
return query, key, value
if preferred_dtype and target_dtype != preferred_dtype:
target_dtype = preferred_dtype
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
if any(module.dtype == torch.float32 for module in [query, key, value]):
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
def patch_fa_peft_integration():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
fixed_fa_peft_integration_check, preferred_dtype=None
)

View File

@@ -380,23 +380,19 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
)
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_flash_attention_forward,
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_flash_attention_forward()
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \