From 8a7a0b07dc5ce6da9171e28a0818b447b6d7cea2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 23 Jan 2025 21:17:57 -0500 Subject: [PATCH] support for latest transformers release 4.48.1 (#2256) --- cicd/cicd.sh | 3 +- requirements.txt | 4 +- src/axolotl/core/trainer_builder.py | 1 + src/axolotl/monkeypatch/trainer_grad_accum.py | 308 ------------------ .../monkeypatch/transformers_fa_utils.py | 67 ++++ src/axolotl/utils/models.py | 18 +- tests/e2e/multigpu/test_llama.py | 16 +- tests/e2e/patched/test_mixtral_samplepack.py | 6 +- tests/e2e/patched/test_model_patches.py | 7 +- tests/e2e/patched/test_unsloth_integration.py | 4 +- tests/e2e/solo/__init__.py | 0 tests/e2e/{ => solo}/test_relora_llama.py | 2 +- tests/patched/test_llama_trainer_ga.py | 25 -- 13 files changed, 98 insertions(+), 363 deletions(-) delete mode 100644 src/axolotl/monkeypatch/trainer_grad_accum.py create mode 100644 src/axolotl/monkeypatch/transformers_fa_utils.py create mode 100644 tests/e2e/solo/__init__.py rename tests/e2e/{ => solo}/test_relora_llama.py (97%) delete mode 100644 tests/patched/test_llama_trainer_ga.py diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 91926127f..34a30db44 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ +pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ -pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ +pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/requirements.txt b/requirements.txt index 1f7ac7bba..52e146411 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,9 @@ liger-kernel==0.5.2 packaging==23.2 peft==0.14.0 -transformers==4.47.1 +transformers==4.48.1 tokenizers>=0.21.0 -accelerate==1.2.1 +accelerate==1.3.0 datasets==3.2.0 deepspeed==0.16.1 trl==0.13.0 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6f1bae1ef..edc842994 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1079,6 +1079,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): super().__init__(*args, **kwargs) self.dataset_tags = dataset_tags self.optimizer = None + self.model_accepts_loss_kwargs = False def create_optimizer(self): if self.args.loraplus_lr_ratio is None: diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py deleted file mode 100644 index 05d706704..000000000 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ /dev/null @@ -1,308 +0,0 @@ -""" -fix for FSDP gradient accumulation -see https://github.com/huggingface/transformers/pull/35128 -""" -import inspect -import logging - -from transformers import LlamaForCausalLM, Trainer -from transformers.modeling_flash_attention_utils import _flash_attention_forward - -from axolotl.monkeypatch.utils import detab_code - -LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") - -ORIGINAL_CONTEXT_CODE = """ - with self.compute_loss_context_manager(): - loss = 
self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) -""" - -PATCHED_CONTEXT_CODE = """ - with self.compute_loss_context_manager(): - if self.model_accepts_loss_kwargs: - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) - else: - loss = self.compute_loss(model, inputs) -""" - -ORIGINAL_LLAMA_FCLM_CODE = """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) -""" - -PATCHED_LLAMA_FCLM_CODE = """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention - num_items_in_batch = kwargs.pop("num_items_in_batch", None) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs) -""" - - -def get_training_step_code() -> str: - training_step = inspect.getsource( - Trainer.training_step # pylint: disable=protected-access - ) - return training_step - - -def check_training_step_is_patchable() -> bool: - training_step = get_training_step_code() - training_step, _ = detab_code(training_step) - return ORIGINAL_CONTEXT_CODE in training_step - - -def patch_training_step_for_ga(): - """ - monkeypatch for fixing the training loop for gradient accumulation - """ - - try: - training_step = get_training_step_code() - except OSError: - return - Trainer._original_training_step = training_step # pylint: disable=protected-access - training_step, _ = detab_code(training_step) - if ORIGINAL_CONTEXT_CODE not in training_step: - return - 
# assert ( - # ORIGINAL_CONTEXT_CODE in training_step - # ), "Original training_step code not found" - - training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE) - training_step = training_step.replace( - "def training_step(", - "def _fixed_training_step(", - 1, - ) - - # load imports necessary - import transformers.trainer - - items_to_import = [] - for item in dir(transformers.trainer): - if item in training_step: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(training_step, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching training_step") - Trainer.training_step = ( # pylint: disable=protected-access - _fixed_training_step # pylint: disable=undefined-variable # noqa: F821 - ) - - -def get_model_forward_code() -> str: - forward = inspect.getsource( - LlamaForCausalLM.forward # pylint: disable=protected-access - ) - return forward - - -def check_forward_is_patchable() -> bool: - forward = get_model_forward_code() - forward, _ = detab_code(forward) - return ORIGINAL_LLAMA_FCLM_CODE in forward - - -def patch_forward_for_ga(): - """ - monkeypatch for fixing the training loop for gradient accumulation - """ - - try: - forward = get_model_forward_code() - except OSError: - return - LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access - forward, _ = detab_code(forward) - if ORIGINAL_LLAMA_FCLM_CODE not in forward: - return - # assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found" - - forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE) - forward = forward.replace( - "def forward(", - "def _fixed_forward(", - 1, - ) - - # load imports necessary - import transformers.models.llama.modeling_llama - - items_to_import = [] - for item in dir(transformers.models.llama.modeling_llama): - if item in forward: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.models.llama.modeling_llama import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching forward") - LlamaForCausalLM.forward = ( # pylint: disable=protected-access - _fixed_forward # pylint: disable=undefined-variable # noqa: F821 - ) - - -ORIGINAL_TRAINER_CODE = """ - context = ( - functools.partial(self.accelerator.no_sync, model=model) - if i != len(batch_samples) - 1 - else contextlib.nullcontext - ) - with context(): - tr_loss_step = self.training_step(model, inputs, num_items_in_batch) -""" - -PATCHED_TRAINER_CODE = """ - disable_deepspeed_no_sync = ( - self.accelerator.distributed_type == DistributedType.DEEPSPEED - # and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients() - ) - context = ( - functools.partial(self.accelerator.no_sync, model=model) - if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync - else contextlib.nullcontext - ) - with context(): - tr_loss_step = self.training_step(model, inputs, num_items_in_batch) -""" - - -def get_training_loop_code() -> str: - training_loop = inspect.getsource( - Trainer._inner_training_loop # pylint: disable=protected-access - ) - return training_loop - - -def check_training_loop_is_patchable() -> bool: - training_loop = get_training_loop_code() - training_loop, _ = detab_code(training_loop) - return ORIGINAL_TRAINER_CODE in 
training_loop - - -def patch_training_loop_for_deepspeed_0_16_x(): - """ - monkeypatch for fixing the training loop for deepspeed GA - - see https://github.com/huggingface/transformers/pull/35157 - """ - - try: - training_loop = get_training_loop_code() - except OSError: - return - Trainer._original_inner_training_loop = ( # pylint: disable=protected-access - training_loop - ) - training_loop, _ = detab_code(training_loop) - if ORIGINAL_TRAINER_CODE not in training_loop: - return - - training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) - training_loop = training_loop.replace( - "def _inner_training_loop(", - "def _fixed_inner_training_loop(", - 1, - ) - - # load imports necessary - import transformers.trainer - - items_to_import = [] - for item in dir(transformers.trainer): - if item in training_loop: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching _inner_training_loop for fsdp optimizer save") - Trainer._inner_training_loop = ( # pylint: disable=protected-access - _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821 - ) - - -def patch_flash_attention_forward(): - """ - monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch - """ - - import transformers.modeling_flash_attention_utils - - def proxy_flash_attention_forward(*args, **kwargs): - kwargs.pop("num_items_in_batch", None) - - return _flash_attention_forward(*args, **kwargs) - - transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access - proxy_flash_attention_forward - ) - transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access - proxy_flash_attention_forward - ) diff --git a/src/axolotl/monkeypatch/transformers_fa_utils.py b/src/axolotl/monkeypatch/transformers_fa_utils.py new file mode 100644 index 000000000..f34ecb8c0 --- /dev/null +++ b/src/axolotl/monkeypatch/transformers_fa_utils.py @@ -0,0 +1,67 @@ +""" +see https://github.com/huggingface/transformers/pull/35834 +""" + +import logging +from functools import partial +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def fixed_fa_peft_integration_check( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + target_dtype: Optional[torch.dtype] = None, + preferred_dtype: Optional[torch.dtype] = None, +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + + Args: + query (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value (`torch.Tensor`): + Input value states to be passed to Flash Attention API + target_dtype (`torch.dtype`, *optional*): + The dtype to convert the attention tensors to. Conversion can be ignored by + not providing the target dtype. + preferred_dtype (`torch.dtype`, *optional*): + The preferred dtype to convert the attention tensors to regardless of the + target dtype. 
+ """ + if target_dtype is None and preferred_dtype is None: + return query, key, value + + if preferred_dtype and target_dtype != preferred_dtype: + target_dtype = preferred_dtype + + # check if any of query, key, or value are in float32. If so, cast them back to target dtype. + if any(module.dtype == torch.float32 for module in [query, key, value]): + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + return query, key, value + + +def patch_fa_peft_integration(): + import transformers.modeling_flash_attention_utils + + transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial( + fixed_fa_peft_integration_check, preferred_dtype=None + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4a665c111..c4b8f05b9 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -380,23 +380,19 @@ class ModelLoader: plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) + if self.cfg.adapter: + from axolotl.monkeypatch.transformers_fa_utils import ( + patch_fa_peft_integration, + ) + + patch_fa_peft_integration() + if self.cfg.gradient_checkpointing == "unsloth": transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper if self.cfg.flash_attention: self.patch_attention() - if self.cfg.model_config_type == "llama": - from axolotl.monkeypatch.trainer_grad_accum import ( - patch_flash_attention_forward, - patch_forward_for_ga, - patch_training_step_for_ga, - ) - - patch_flash_attention_forward() - patch_forward_for_ga() - patch_training_step_for_ga() - if self.cfg.sample_packing and self.cfg.s2_attention: raise ValueError( "Received `sample_packing=true` and `s2_attention=true`; however, \ diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 7135ad805..bdbd99587 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -63,6 +63,7 @@ class TestMultiGPULlama: "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, + "bf16": True, } ) @@ -127,6 +128,7 @@ class TestMultiGPULlama: "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, + "bf16": True, } ) @@ -201,6 +203,7 @@ class TestMultiGPULlama: "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, + "bf16": True, } ) @@ -223,8 +226,12 @@ class TestMultiGPULlama: ] ) + loss_threshold = 2.3 check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss is too high", ) def test_dpo_qlora_ddp(self, temp_dir): @@ -275,6 +282,7 @@ class TestMultiGPULlama: "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, + "bf16": True, } ) @@ -297,8 +305,12 @@ class TestMultiGPULlama: ] ) + loss_threshold = 2.3 check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss is too high", ) @pytest.mark.parametrize( diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 156dac7e8..8746c923b 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ 
b/tests/e2e/patched/test_mixtral_samplepack.py @@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, dataset_meta=dataset_meta) - assert ( - "MixtralFlashAttention2" - in model.model.layers[0].self_attn.__class__.__name__ - ) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 78b01be64..c6a13af19 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase): ) normalize_config(cfg) tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=False) - - assert ( - "MixtralFlashAttention2" - in model.model.layers[0].self_attn.__class__.__name__ - ) + load_model(cfg, tokenizer, inference=False) @with_temp_dir def test_mistral_multipack(self, temp_dir): diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py index bc6476dab..403d26147 100644 --- a/tests/e2e/patched/test_unsloth_integration.py +++ b/tests/e2e/patched/test_unsloth_integration.py @@ -3,8 +3,6 @@ import unittest import pytest -from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable - @pytest.mark.skip( reason="Unsloth integration will be broken going into latest transformers" @@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase): """Unsloth monkeypatch integration tests.""" def test_is_self_attn_patchable(self): + from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable + # ensures the current version of transformers has loss code that matches our patching code self.assertTrue( check_self_attn_is_patchable(), diff --git a/tests/e2e/solo/__init__.py b/tests/e2e/solo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py similarity index 97% rename from tests/e2e/test_relora_llama.py rename to tests/e2e/solo/test_relora_llama.py index 6c785dc86..191f76f64 100644 --- a/tests/e2e/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -13,7 +13,7 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_model_output_exists, check_tensorboard, with_temp_dir +from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/patched/test_llama_trainer_ga.py b/tests/patched/test_llama_trainer_ga.py deleted file mode 100644 index 58c229cf3..000000000 --- a/tests/patched/test_llama_trainer_ga.py +++ /dev/null @@ -1,25 +0,0 @@ -""""Test module for checking whether the Hugging Face Transformers is working as expected.""" -import unittest - -from axolotl.monkeypatch.trainer_grad_accum import ( - check_forward_is_patchable, - check_training_step_is_patchable, -) - - -class TestTrainerGAIntegration(unittest.TestCase): - """llama monkeypatch integration tests.""" - - def test_train_step_patchable(self): - # ensures the current version of transformers has loss code that matches our patching code - self.assertTrue( - check_training_step_is_patchable(), - "HF transformers Trainer.training_step has changed and isn't patchable", - ) - - def test_model_forward_patchable(self): - # ensures the 
current version of transformers has loss code that matches our patching code - self.assertTrue( - check_forward_is_patchable(), - "HF transformers LlamaForCausalLM.forward has changed and isn't patchable", - )
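
Note (illustrative, not part of the applied diff): with the gradient accumulation
monkeypatches deleted, the fa_peft_integration fix added in
src/axolotl/monkeypatch/transformers_fa_utils.py is installed before model load
whenever an adapter is configured. A minimal sketch of that wiring, assuming
axolotl's DictDefault config is used as in the tests above; the config values
below are hypothetical:

    # sketch: mirrors the hook this patch adds to axolotl's ModelLoader
    from axolotl.monkeypatch.transformers_fa_utils import patch_fa_peft_integration
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault({"adapter": "qlora"})  # hypothetical adapter setting
    if cfg.adapter:
        # swaps transformers' fa_peft_integration_check for the fixed version;
        # pinning preferred_dtype=None means query/key/value are only cast
        # when an explicit target_dtype is supplied (see HF PR #35834)
        patch_fa_peft_integration()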
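
Note (illustrative): the fixed check is a no-op unless a dtype is actually
requested, which appears to be the point of the upstream fix referenced in the
new module (HF PR #35834). A quick sanity sketch, assuming torch and axolotl
are importable:

    import torch
    from axolotl.monkeypatch.transformers_fa_utils import fixed_fa_peft_integration_check

    q = k = v = torch.randn(2, 8, 64, dtype=torch.float32)
    # with neither target_dtype nor preferred_dtype set, the tensors
    # pass through unchanged instead of being silently cast
    q2, k2, v2 = fixed_fa_peft_integration_check(q, k, v)
    assert q2.dtype == torch.float32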