latest fixes needed for GA in latest transformers
This commit is contained in:
@@ -14,15 +14,78 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
|||||||
|
|
||||||
ORIGINAL_CONTEXT_CODE = """
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
with self.compute_loss_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
|
||||||
|
del inputs
|
||||||
|
if (
|
||||||
|
self.args.torch_empty_cache_steps is not None
|
||||||
|
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
||||||
|
):
|
||||||
|
if is_torch_xpu_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
elif is_torch_mlu_available():
|
||||||
|
torch.mlu.empty_cache()
|
||||||
|
elif is_torch_musa_available():
|
||||||
|
torch.musa.empty_cache()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
elif is_torch_mps_available(min_version="2.0"):
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
else:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
kwargs = {}
|
||||||
|
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||||
|
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
|
kwargs["learning_rate"] = self._get_learning_rate()
|
||||||
|
if self.args.n_gpu > 1:
|
||||||
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
if self.use_apex:
|
||||||
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
else:
|
||||||
|
# Finally we need to normalize the loss for reporting
|
||||||
|
if num_items_in_batch is None:
|
||||||
|
loss = loss / self.args.gradient_accumulation_steps
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATCHED_CONTEXT_CODE = """
|
PATCHED_CONTEXT_CODE = """
|
||||||
with self.compute_loss_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
if self.model_accepts_loss_kwargs:
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
|
||||||
|
del inputs
|
||||||
|
if (
|
||||||
|
self.args.torch_empty_cache_steps is not None
|
||||||
|
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
||||||
|
):
|
||||||
|
if is_torch_xpu_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
elif is_torch_mlu_available():
|
||||||
|
torch.mlu.empty_cache()
|
||||||
|
elif is_torch_musa_available():
|
||||||
|
torch.musa.empty_cache()
|
||||||
|
elif is_torch_npu_available():
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
elif is_torch_mps_available(min_version="2.0"):
|
||||||
|
torch.mps.empty_cache()
|
||||||
else:
|
else:
|
||||||
loss = self.compute_loss(model, inputs)
|
torch.cuda.empty_cache()
|
||||||
|
kwargs = {}
|
||||||
|
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||||
|
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
|
kwargs["learning_rate"] = self._get_learning_rate()
|
||||||
|
if self.args.n_gpu > 1:
|
||||||
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
if self.use_apex:
|
||||||
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
else:
|
||||||
|
# Finally we need to normalize the loss for reporting
|
||||||
|
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
|
||||||
|
loss = loss / self.args.gradient_accumulation_steps
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ORIGINAL_LLAMA_FCLM_CODE = """
|
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||||
|
|||||||
@@ -387,13 +387,14 @@ class ModelLoader:
|
|||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|
||||||
if self.cfg.model_config_type == "llama":
|
if self.cfg.model_config_type == "llama":
|
||||||
from axolotl.monkeypatch.trainer_grad_accum import ( # patch_forward_for_ga,; patch_training_step_for_ga,
|
from axolotl.monkeypatch.trainer_grad_accum import ( # patch_forward_for_ga,
|
||||||
patch_flash_attention_forward,
|
patch_flash_attention_forward,
|
||||||
|
patch_training_step_for_ga,
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_flash_attention_forward()
|
patch_flash_attention_forward()
|
||||||
# patch_forward_for_ga()
|
# patch_forward_for_ga()
|
||||||
# patch_training_step_for_ga()
|
patch_training_step_for_ga()
|
||||||
|
|
||||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
|
||||||
"MixtralFlashAttention2"
|
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
|
||||||
)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
model, _ = load_model(cfg, tokenizer, inference=False)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
|
|
||||||
assert (
|
|
||||||
"MixtralFlashAttention2"
|
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_mistral_multipack(self, temp_dir):
|
def test_mistral_multipack(self, temp_dir):
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from axolotl.monkeypatch.trainer_grad_accum import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("may not be needed for latest transformers version")
|
|
||||||
class TestTrainerGAIntegration(unittest.TestCase):
|
class TestTrainerGAIntegration(unittest.TestCase):
|
||||||
"""llama monkeypatch integration tests."""
|
"""llama monkeypatch integration tests."""
|
||||||
|
|
||||||
@@ -20,6 +19,7 @@ class TestTrainerGAIntegration(unittest.TestCase):
|
|||||||
"HF transformers Trainer.training_step has changed and isn't patchable",
|
"HF transformers Trainer.training_step has changed and isn't patchable",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip("may not be needed for latest transformers version")
|
||||||
def test_model_forward_patchable(self):
|
def test_model_forward_patchable(self):
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
Reference in New Issue
Block a user