Compare commits

12 Commits

`diffusion-...` → `hf-trainer`

| Author | SHA1 | Date |
|---|---|---|
| | 5e8c492e3c | |
| | 9a683536c8 | |
| | faa61a9c3e | |
| | 59cb36564d | |
| | 50d4d727a0 | |
| | 0714a49227 | |
| | b6daffb788 | |
| | d487e377fa | |
| | 4cc89f73f0 | |
| | 5b5ba49c46 | |
| | 49b5501fc2 | |
| | 23389b38b7 | |
```diff
@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
 # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
+pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
-pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
```
```diff
@@ -13,9 +13,9 @@ liger-kernel==0.5.2
 packaging==23.2
 
 peft==0.14.0
-transformers==4.47.1
+transformers @ git+https://github.com/huggingface/transformers.git@mueller-trainer-refactor
 tokenizers>=0.21.0
-accelerate==1.2.1
+accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
 trl==0.13.0
```
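The pins above swap the transformers release for the `mueller-trainer-refactor` development branch and bump accelerate. Purely as an illustration (not part of this compare), a quick environment check along these lines can confirm the pins resolved; the exact version string a git install of transformers reports is an assumption, so only the accelerate pin is asserted:

```python
# Illustrative sanity check only; not part of the diff.
from importlib.metadata import version

assert version("accelerate") == "1.3.0"
# A git install of transformers typically reports a ".dev0" version rather than a release pin.
print("transformers:", version("transformers"))
```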
```diff
@@ -14,15 +14,85 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
 
 ORIGINAL_CONTEXT_CODE = """
         with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+            if self.model_accepts_loss_kwargs:
+                loss = self.compute_loss(model, inputs)
+            else:
+                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
+            else:
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if num_items_in_batch is None:
+                loss = loss / self.args.gradient_accumulation_steps
 """
 
 PATCHED_CONTEXT_CODE = """
         with self.compute_loss_context_manager():
-            if self.model_accepts_loss_kwargs:
-                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
             else:
-                loss = self.compute_loss(model, inputs)
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+                loss = loss / self.args.gradient_accumulation_steps
 """
 
 ORIGINAL_LLAMA_FCLM_CODE = """
```
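For context, `ORIGINAL_CONTEXT_CODE` and `PATCHED_CONTEXT_CODE` feed a check-then-replace source patch of `Trainer.training_step`: the original block must be found verbatim in the upstream source before it is swapped for the patched block, which is why the strings have to track the `mueller-trainer-refactor` code exactly. A minimal sketch of that pattern, assuming the two constants above are in scope; the helper bodies below are illustrative, not the exact axolotl implementation:

```python
# Minimal sketch of the check-then-replace patching pattern these constants feed.
# Only the two *_CONTEXT_CODE strings come from the diff; everything else is illustrative.
import inspect
import textwrap

from transformers import Trainer


def check_training_step_is_patchable() -> bool:
    # Patching is only safe while upstream Trainer.training_step still contains
    # the exact block we expect to replace.
    source = inspect.getsource(Trainer.training_step)
    return ORIGINAL_CONTEXT_CODE.strip("\n") in source


def patch_training_step_for_ga() -> None:
    # Rebuild training_step from its source text with the block swapped out,
    # compile it against the trainer module's globals, and rebind it on the class.
    source = inspect.getsource(Trainer.training_step)
    assert ORIGINAL_CONTEXT_CODE.strip("\n") in source, "Trainer.training_step has changed"
    source = source.replace(ORIGINAL_CONTEXT_CODE.strip("\n"), PATCHED_CONTEXT_CODE.strip("\n"))
    namespace: dict = {}
    exec(textwrap.dedent(source), Trainer.training_step.__globals__, namespace)  # nosec
    Trainer.training_step = namespace["training_step"]
```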
```diff
@@ -386,16 +386,15 @@ class ModelLoader:
         if self.cfg.flash_attention:
             self.patch_attention()
 
-        if self.cfg.model_config_type == "llama":
-            from axolotl.monkeypatch.trainer_grad_accum import (
-                patch_flash_attention_forward,
-                patch_forward_for_ga,
-                patch_training_step_for_ga,
-            )
-
-            patch_flash_attention_forward()
-            patch_forward_for_ga()
-            patch_training_step_for_ga()
+        # if self.cfg.model_config_type == "llama":
+        #     from axolotl.monkeypatch.trainer_grad_accum import (  # patch_forward_for_ga,
+        #         patch_flash_attention_forward,
+        #         patch_training_step_for_ga,
+        #     )
+        #
+        #     patch_flash_attention_forward()
+        #     # patch_forward_for_ga()
+        #     patch_training_step_for_ga()
 
         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(
```
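With the llama-specific grad-accum patches commented out while the branch tracks the upstream refactor, one option for re-enabling them later is to gate the patch on the patchability check so a future transformers bump degrades to a warning instead of a hard failure. A hypothetical helper sketch (not what the diff does; the function name and logger are assumptions, the imported patch functions are the ones named in the hunk above):

```python
# Illustrative sketch only, not part of the compare.
import logging

LOG = logging.getLogger(__name__)


def apply_llama_ga_patches(cfg) -> None:
    """Guarded variant: only patch when the upstream source still matches."""
    if cfg.model_config_type != "llama":
        return

    from axolotl.monkeypatch.trainer_grad_accum import (
        check_training_step_is_patchable,
        patch_flash_attention_forward,
        patch_training_step_for_ga,
    )

    patch_flash_attention_forward()
    if check_training_step_is_patchable():
        patch_training_step_for_ga()
    else:
        LOG.warning("Trainer.training_step no longer matches the GA patch; skipping it")
```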
```diff
@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
```
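The test now asserts on produced artifacts via `check_model_output_exists` instead of inspecting the attention implementation class. A hypothetical stand-in for what such an output check can look like; the real helper lives in the shared test utils, also receives the run config, and may differ:

```python
# Hypothetical sketch of an output-existence check; illustrative only.
from pathlib import Path


def model_output_exists(temp_dir: str) -> bool:
    # A successful run should have written either adapter weights (LoRA runs)
    # or full model weights to the output directory.
    candidates = ("adapter_model.safetensors", "model.safetensors", "pytorch_model.bin")
    return any((Path(temp_dir) / name).exists() for name in candidates)
```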
```diff
@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
         )
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        model, _ = load_model(cfg, tokenizer, inference=False)
-
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        load_model(cfg, tokenizer, inference=False)
 
     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
```
```diff
@@ -3,8 +3,6 @@ import unittest
 
 import pytest
 
-from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
-
 
 @pytest.mark.skip(
     reason="Unsloth integration will be broken going into latest transformers"
```
```diff
@@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase):
     """Unsloth monkeypatch integration tests."""
 
     def test_is_self_attn_patchable(self):
+        from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
+
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
             check_self_attn_is_patchable(),
```
tests/e2e/solo/__init__.py (new normal file, 0 lines changed)
```diff
@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
+from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
```
```diff
@@ -1,6 +1,8 @@
 """"Test module for checking whether the Hugging Face Transformers is working as expected."""
 import unittest
 
+import pytest
+
 from axolotl.monkeypatch.trainer_grad_accum import (
     check_forward_is_patchable,
     check_training_step_is_patchable,
```
```diff
@@ -10,6 +12,7 @@ from axolotl.monkeypatch.trainer_grad_accum import (
 class TestTrainerGAIntegration(unittest.TestCase):
     """llama monkeypatch integration tests."""
 
+    @pytest.mark.skip("may not be needed for latest transformers version")
     def test_train_step_patchable(self):
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
```
```diff
@@ -17,6 +20,7 @@ class TestTrainerGAIntegration(unittest.TestCase):
             "HF transformers Trainer.training_step has changed and isn't patchable",
         )
 
+    @pytest.mark.skip("may not be needed for latest transformers version")
     def test_model_forward_patchable(self):
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
```
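The hunks above only show the opening lines of each guard test. For orientation, a complete test of this shape looks roughly like the sketch below; only the imported name, the skip reason, the comment, and the failure message come from the diff, the rest is an assumption about the surrounding boilerplate:

```python
# Hedged sketch of one complete guard test from this compare.
import unittest

import pytest

from axolotl.monkeypatch.trainer_grad_accum import check_training_step_is_patchable


class TestTrainerGAIntegration(unittest.TestCase):
    """llama monkeypatch integration tests."""

    @pytest.mark.skip("may not be needed for latest transformers version")
    def test_train_step_patchable(self):
        # ensures the current version of transformers has loss code that matches our patching code
        self.assertTrue(
            check_training_step_is_patchable(),
            "HF transformers Trainer.training_step has changed and isn't patchable",
        )
```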