Compare commits


4 Commits

Author           SHA1        Message                                                              Date
Wing Lian        6c49083d8b  improve check for base case                                          2025-01-24 12:02:34 -05:00
Wing Lian        94c226edb3  fixes last eos token not in labels on basic use case                 2025-01-24 12:00:06 -05:00
Wing Lian        8fb72cbc0b  use the extracted field_messages to parse the role fields (#2265)   2025-01-21 15:39:30 -05:00
Adithya Kamath   bb9d4102c4  Add 5000 line history limit to tmux for docker cloud (#2268)        2025-01-21 15:39:17 -05:00
13 changed files with 35 additions and 99 deletions

View File

@@ -6,6 +6,5 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
 # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
-pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
-pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \ printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \ chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"] ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"] CMD ["sleep", "infinity"]

View File

@@ -13,9 +13,9 @@ liger-kernel==0.5.2
 packaging==23.2
 peft==0.14.0
-transformers @ git+https://github.com/huggingface/transformers.git@mueller-trainer-refactor
+transformers==4.47.1
 tokenizers>=0.21.0
-accelerate==1.3.0
+accelerate==1.2.1
 datasets==3.2.0
 deepspeed==0.16.1
 trl==0.13.0
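
(Aside, not part of the compare: a quick sanity check that a running environment matches the pins introduced above. It assumes only the standard __version__ attributes; the version strings are copied from the new requirements lines.)

import accelerate
import transformers

# Confirm the installed packages match the pinned requirements above.
assert transformers.__version__ == "4.47.1", transformers.__version__
assert accelerate.__version__ == "1.2.1", accelerate.__version__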

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
     )
     ds_cfg["field_messages"] = field_messages
-    message_fields = features["conversations"][0].keys()
+    message_fields = features[field_messages][0].keys()
     message_field_role = None
     for key in ["from", "role"]:
         if key in message_fields:
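
(Aside: the hunk above swaps the hard-coded "conversations" key for the extracted field_messages value when probing message structure. A minimal standalone sketch of that detection logic follows; the function name and sample data are illustrative, not the repo's API.)

def detect_role_field(features, field_messages):
    # Look at the keys of the first message under the configured messages field
    # and pick whichever of "from"/"role" the dataset actually uses.
    message_fields = features[field_messages][0].keys()
    for key in ("from", "role"):
        if key in message_fields:
            return key
    return None

sample = {"messages": [{"role": "user", "content": "hi"}]}
assert detect_role_field(sample, "messages") == "role"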

View File

@@ -14,85 +14,15 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """ ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager(): with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
else:
torch.cuda.empty_cache()
kwargs = {}
# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
loss = loss / self.args.gradient_accumulation_steps
""" """
PATCHED_CONTEXT_CODE = """ PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager(): with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
else: else:
torch.cuda.empty_cache() loss = self.compute_loss(model, inputs)
kwargs = {}
# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
""" """
ORIGINAL_LLAMA_FCLM_CODE = """ ORIGINAL_LLAMA_FCLM_CODE = """

View File

@@ -223,7 +223,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
         # Old simple legacy behavior that works reliably.
         if (
-            not self.roles_to_train
+            (not self.roles_to_train or self.roles_to_train == ["assistant"])
             and not self.train_on_eos
             and not self.prompter.message_field_training
             and not self.prompter.message_field_training_detail
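
(Aside: the "improve check for base case" commit widens this fast path so that the default roles_to_train of ["assistant"] also takes the simple legacy tokenization route. A hedged restatement of the predicate, with illustrative argument names:)

def uses_legacy_path(roles_to_train, train_on_eos, field_training, field_training_detail):
    # The legacy path applies when the training targets are effectively the defaults.
    return (
        (not roles_to_train or roles_to_train == ["assistant"])
        and not train_on_eos
        and not field_training
        and not field_training_detail
    )

assert uses_legacy_path(["assistant"], None, None, None)           # default now hits the legacy path
assert not uses_legacy_path(["user", "assistant"], None, None, None)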

View File

@@ -386,15 +386,16 @@ class ModelLoader:
         if self.cfg.flash_attention:
             self.patch_attention()
-        # if self.cfg.model_config_type == "llama":
-        #     from axolotl.monkeypatch.trainer_grad_accum import (  # patch_forward_for_ga,
-        #         patch_flash_attention_forward,
-        #         patch_training_step_for_ga,
-        #     )
-        #
-        #     patch_flash_attention_forward()
-        #     # patch_forward_for_ga()
-        #     patch_training_step_for_ga()
+        if self.cfg.model_config_type == "llama":
+            from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_flash_attention_forward,
+                patch_forward_for_ga,
+                patch_training_step_for_ga,
+            )
+
+            patch_flash_attention_forward()
+            patch_forward_for_ga()
+            patch_training_step_for_ga()

         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(

View File

@@ -102,5 +102,9 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-        train(cfg=cfg, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
+        assert (
+            "MixtralFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
         check_model_output_exists(temp_dir, cfg)

View File

@@ -49,7 +49,12 @@ class TestModelPatches(unittest.TestCase):
         )
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        model, _ = load_model(cfg, tokenizer, inference=False)
+        assert (
+            "MixtralFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )

     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
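
(Aside: both new assertions rely on the active attention implementation being visible in the class name of the first decoder layer's self_attn module. A generic, hedged way to do the same inspection outside the test suite; the checkpoint is illustrative and flash-attn must be installed for the FA2 path.)

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",            # illustrative checkpoint
    attn_implementation="flash_attention_2",  # requires flash-attn
)
# The attention implementation shows up in the module's class name,
# e.g. "MixtralFlashAttention2" when the flash-attention path is in use.
print(model.model.layers[0].self_attn.__class__.__name__)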

View File

@@ -3,6 +3,8 @@ import unittest
 import pytest

+from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
+
 @pytest.mark.skip(
     reason="Unsloth integration will be broken going into latest transformers"
@@ -11,8 +13,6 @@ class TestUnslothIntegration(unittest.TestCase):
     """Unsloth monkeypatch integration tests."""

     def test_is_self_attn_patchable(self):
-        from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
-
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
             check_self_attn_is_patchable(),

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
-from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,8 +1,6 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected.""" """"Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest import unittest
import pytest
from axolotl.monkeypatch.trainer_grad_accum import ( from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable, check_forward_is_patchable,
check_training_step_is_patchable, check_training_step_is_patchable,
@@ -12,7 +10,6 @@ from axolotl.monkeypatch.trainer_grad_accum import (
class TestTrainerGAIntegration(unittest.TestCase): class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests.""" """llama monkeypatch integration tests."""
@pytest.mark.skip("may not be needed for latest transformers version")
def test_train_step_patchable(self): def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code # ensures the current version of transformers has loss code that matches our patching code
self.assertTrue( self.assertTrue(
@@ -20,7 +17,6 @@ class TestTrainerGAIntegration(unittest.TestCase):
"HF transformers Trainer.training_step has changed and isn't patchable", "HF transformers Trainer.training_step has changed and isn't patchable",
) )
@pytest.mark.skip("may not be needed for latest transformers version")
def test_model_forward_patchable(self): def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code # ensures the current version of transformers has loss code that matches our patching code
self.assertTrue( self.assertTrue(