Compare commits
4 Commits
hf-trainer
...
eos-hell
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c49083d8b | ||
|
|
94c226edb3 | ||
|
|
8fb72cbc0b | ||
|
|
bb9d4102c4 |
@@ -6,6 +6,5 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh
|
||||
chmod +x /root/cloud-entrypoint.sh && \
|
||||
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
||||
|
||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||
CMD ["sleep", "infinity"]
|
||||
|
||||
@@ -13,9 +13,9 @@ liger-kernel==0.5.2
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@mueller-trainer-refactor
|
||||
transformers==4.47.1
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.3.0
|
||||
accelerate==1.2.1
|
||||
datasets==3.2.0
|
||||
deepspeed==0.16.1
|
||||
trl==0.13.0
|
||||
|
||||
@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
|
||||
)
|
||||
ds_cfg["field_messages"] = field_messages
|
||||
|
||||
message_fields = features["conversations"][0].keys()
|
||||
message_fields = features[field_messages][0].keys()
|
||||
message_field_role = None
|
||||
for key in ["from", "role"]:
|
||||
if key in message_fields:
|
||||
|
||||
@@ -14,85 +14,15 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||
|
||||
ORIGINAL_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
|
||||
del inputs
|
||||
if (
|
||||
self.args.torch_empty_cache_steps is not None
|
||||
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
||||
):
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_mlu_available():
|
||||
torch.mlu.empty_cache()
|
||||
elif is_torch_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif is_torch_mps_available(min_version="2.0"):
|
||||
torch.mps.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
kwargs = {}
|
||||
|
||||
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||
kwargs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
|
||||
if self.use_apex:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
# Finally we need to normalize the loss for reporting
|
||||
if num_items_in_batch is None:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
"""
|
||||
|
||||
PATCHED_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
|
||||
del inputs
|
||||
if (
|
||||
self.args.torch_empty_cache_steps is not None
|
||||
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
||||
):
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_mlu_available():
|
||||
torch.mlu.empty_cache()
|
||||
elif is_torch_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif is_torch_mps_available(min_version="2.0"):
|
||||
torch.mps.empty_cache()
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
kwargs = {}
|
||||
|
||||
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||
kwargs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
|
||||
if self.use_apex:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
# Finally we need to normalize the loss for reporting
|
||||
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
loss = self.compute_loss(model, inputs)
|
||||
"""
|
||||
|
||||
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||
|
||||
@@ -223,7 +223,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Old simple legacy behavior that works reliably.
|
||||
if (
|
||||
not self.roles_to_train
|
||||
(not self.roles_to_train or self.roles_to_train == ["assistant"])
|
||||
and not self.train_on_eos
|
||||
and not self.prompter.message_field_training
|
||||
and not self.prompter.message_field_training_detail
|
||||
|
||||
@@ -386,15 +386,16 @@ class ModelLoader:
|
||||
if self.cfg.flash_attention:
|
||||
self.patch_attention()
|
||||
|
||||
# if self.cfg.model_config_type == "llama":
|
||||
# from axolotl.monkeypatch.trainer_grad_accum import ( # patch_forward_for_ga,
|
||||
# patch_flash_attention_forward,
|
||||
# patch_training_step_for_ga,
|
||||
# )
|
||||
#
|
||||
# patch_flash_attention_forward()
|
||||
# # patch_forward_for_ga()
|
||||
# patch_training_step_for_ga()
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
patch_flash_attention_forward,
|
||||
patch_forward_for_ga,
|
||||
patch_training_step_for_ga,
|
||||
)
|
||||
|
||||
patch_flash_attention_forward()
|
||||
patch_forward_for_ga()
|
||||
patch_training_step_for_ga()
|
||||
|
||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
raise ValueError(
|
||||
|
||||
@@ -102,5 +102,9 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
"MixtralFlashAttention2"
|
||||
in model.model.layers[0].self_attn.__class__.__name__
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -49,7 +49,12 @@ class TestModelPatches(unittest.TestCase):
|
||||
)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
model, _ = load_model(cfg, tokenizer, inference=False)
|
||||
|
||||
assert (
|
||||
"MixtralFlashAttention2"
|
||||
in model.model.layers[0].self_attn.__class__.__name__
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
|
||||
@@ -3,6 +3,8 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
@@ -11,8 +13,6 @@ class TestUnslothIntegration(unittest.TestCase):
|
||||
"""Unsloth monkeypatch integration tests."""
|
||||
|
||||
def test_is_self_attn_patchable(self):
|
||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
check_self_attn_is_patchable(),
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -1,8 +1,6 @@
|
||||
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
check_forward_is_patchable,
|
||||
check_training_step_is_patchable,
|
||||
@@ -12,7 +10,6 @@ from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
class TestTrainerGAIntegration(unittest.TestCase):
|
||||
"""llama monkeypatch integration tests."""
|
||||
|
||||
@pytest.mark.skip("may not be needed for latest transformers version")
|
||||
def test_train_step_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
@@ -20,7 +17,6 @@ class TestTrainerGAIntegration(unittest.TestCase):
|
||||
"HF transformers Trainer.training_step has changed and isn't patchable",
|
||||
)
|
||||
|
||||
@pytest.mark.skip("may not be needed for latest transformers version")
|
||||
def test_model_forward_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
|
||||
Reference in New Issue
Block a user