From 5b5ba49c46758f9430e12eaa049192b7032e3db8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 13 Jan 2025 13:36:47 -0500
Subject: [PATCH] latest fixes needed for GA in latest transformers

---
 src/axolotl/monkeypatch/trainer_grad_accum.py | 69 ++++++++++++++++++-
 src/axolotl/utils/models.py                   |  5 +-
 tests/e2e/patched/test_mixtral_samplepack.py  |  6 +-
 tests/e2e/patched/test_model_patches.py       |  7 +-
 tests/patched/test_llama_trainer_ga.py        |  2 +-
 5 files changed, 72 insertions(+), 17 deletions(-)

diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py
index 05d706704..8fc498cff 100644
--- a/src/axolotl/monkeypatch/trainer_grad_accum.py
+++ b/src/axolotl/monkeypatch/trainer_grad_accum.py
@@ -14,15 +14,78 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
 ORIGINAL_CONTEXT_CODE = """
         with self.compute_loss_context_manager():
+            if self.model_accepts_loss_kwargs:
+                loss = self.compute_loss(model, inputs)
+            else:
+                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
             loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
+            else:
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if num_items_in_batch is None:
+                loss = loss / self.args.gradient_accumulation_steps
 """
 
 PATCHED_CONTEXT_CODE = """
         with self.compute_loss_context_manager():
-            if self.model_accepts_loss_kwargs:
-                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
             else:
-                loss = self.compute_loss(model, inputs)
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+                loss = loss / self.args.gradient_accumulation_steps
 """
 
 ORIGINAL_LLAMA_FCLM_CODE = """
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index bff72a1d7..d90d67407 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -387,13 +387,14 @@ class ModelLoader:
         self.patch_attention()
 
         if self.cfg.model_config_type == "llama":
-            from axolotl.monkeypatch.trainer_grad_accum import (  # patch_forward_for_ga,; patch_training_step_for_ga,
+            from axolotl.monkeypatch.trainer_grad_accum import (  # patch_forward_for_ga,
                 patch_flash_attention_forward,
+                patch_training_step_for_ga,
             )
 
             patch_flash_attention_forward()
             # patch_forward_for_ga()
-            # patch_training_step_for_ga()
+            patch_training_step_for_ga()
 
         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 156dac7e8..8746c923b 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 78b01be64..c6a13af19 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
         )
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        model, _ = load_model(cfg, tokenizer, inference=False)
-
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        load_model(cfg, tokenizer, inference=False)
 
     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
diff --git a/tests/patched/test_llama_trainer_ga.py b/tests/patched/test_llama_trainer_ga.py
index 8a0b338ae..8b354c57e 100644
--- a/tests/patched/test_llama_trainer_ga.py
+++ b/tests/patched/test_llama_trainer_ga.py
@@ -9,7 +9,6 @@ from axolotl.monkeypatch.trainer_grad_accum import (
 )
 
 
-@pytest.mark.skip("may not be needed for latest transformers version")
 class TestTrainerGAIntegration(unittest.TestCase):
     """llama monkeypatch integration tests."""
 
@@ -20,6 +19,7 @@ class TestTrainerGAIntegration(unittest.TestCase):
         "HF transformers Trainer.training_step has changed and isn't patchable",
     )
 
+    @pytest.mark.skip("may not be needed for latest transformers version")
     def test_model_forward_patchable(self):
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
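
---

Note on mechanism: patch_training_step_for_ga() applies PATCHED_CONTEXT_CODE by
string-matching ORIGINAL_CONTEXT_CODE against the source of HF's
Trainer.training_step. That is why the ORIGINAL string must mirror upstream
verbatim (it appears to preserve upstream's own "learnign rate" comment typo on
purpose), and why the tests skip with "HF transformers Trainer.training_step
has changed and isn't patchable" when the match fails. Below is a minimal
sketch of that inspect/replace/exec technique; apply_source_patch is an
illustrative stand-in, not the repo's actual helper.

    import inspect
    import textwrap

    import transformers.trainer
    from transformers.trainer import Trainer


    def apply_source_patch(original_snippet: str, patched_snippet: str) -> None:
        """Swap a known snippet inside Trainer.training_step and rebind it."""
        source = inspect.getsource(Trainer.training_step)
        # The snippet constants carry the method body's exact indentation, so a
        # plain substring test both locates the code and fails loudly when a
        # new transformers release rewrites training_step.
        if original_snippet not in source:
            raise AssertionError(
                "HF transformers Trainer.training_step has changed and isn't patchable"
            )
        # Splice in the replacement, then de-indent so the method source can be
        # re-executed as a top-level function definition.
        source = textwrap.dedent(source.replace(original_snippet, patched_snippet))
        # Exec in a copy of the trainer module's globals so the rebuilt function
        # still resolves names like OptimizerNames, amp, and
        # is_torch_xpu_available without mutating the real module namespace.
        namespace = dict(transformers.trainer.__dict__)
        exec(source, namespace)  # source is built locally, not user input
        Trainer.training_step = namespace["training_step"]

    # usage: apply_source_patch(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)

Executing into a copied namespace is a deliberate choice here: the recompiled
function keeps access to everything trainer.py imports, while the module
itself is left untouched; only the Trainer.training_step attribute is rebound.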