Compare commits

...

13 Commits

Author SHA1 Message Date
Wing Lian
5e8c492e3c trainer refactor testing for hf#35567 2025-01-21 11:27:10 -05:00
Wing Lian
9a683536c8 upgrade accelerate also 2025-01-21 10:15:16 -05:00
Wing Lian
faa61a9c3e use official hf release for 4.48.1 2025-01-21 10:15:01 -05:00
Wing Lian
59cb36564d skip check for latest transformers 2025-01-21 10:15:01 -05:00
Wing Lian
50d4d727a0 use wip branch for expected 4.48.1 2025-01-21 10:15:00 -05:00
Wing Lian
0714a49227 move relora test so it runs in a single test thread 2025-01-21 10:15:00 -05:00
Wing Lian
b6daffb788 fix import from mv 2025-01-21 10:15:00 -05:00
Wing Lian
d487e377fa move relora to the patched tests suite 2025-01-21 10:15:00 -05:00
Wing Lian
4cc89f73f0 fix patch 2025-01-21 10:15:00 -05:00
Wing Lian
5b5ba49c46 latest fixes needed for GA in latest transformers 2025-01-21 10:15:00 -05:00
Wing Lian
49b5501fc2 unsloth incompatible with latest transformers 2025-01-21 10:15:00 -05:00
Wing Lian
23389b38b7 bump to latest transformers release 2025-01-21 10:15:00 -05:00
Wing Lian
af727eedf7 option to not concatenate during pretraining (#2263)
* option to not concatenate during pretraining

* simplify conditional and add doc to config.qmd
2025-01-20 14:07:34 -05:00
14 changed files with 115 additions and 31 deletions

View File

@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -244,6 +244,8 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
# Use batch flattening for speedups when not using sample_packing
batch_flattening:

View File

@@ -13,9 +13,9 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.47.1
transformers @ git+https://github.com/huggingface/transformers.git@mueller-trainer-refactor
tokenizers>=0.21.0
accelerate==1.2.1
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.13.0

View File

@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -14,15 +14,85 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
else:
torch.cuda.empty_cache()
kwargs = {}
# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
loss = loss / self.args.gradient_accumulation_steps
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
else:
loss = self.compute_loss(model, inputs)
torch.cuda.empty_cache()
kwargs = {}
# For LOMO optimizers you need to explicitly use the learnign rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
"""
ORIGINAL_LLAMA_FCLM_CODE = """

View File

@@ -706,6 +706,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None

View File

@@ -22,6 +22,7 @@ def encode_pretraining(
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples[text_column],
@@ -33,6 +34,13 @@ def encode_pretraining(
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets:

View File

@@ -386,16 +386,15 @@ class ModelLoader:
if self.cfg.flash_attention:
self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_flash_attention_forward,
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_flash_attention_forward()
patch_forward_for_ga()
patch_training_step_for_ga()
# if self.cfg.model_config_type == "llama":
# from axolotl.monkeypatch.trainer_grad_accum import ( # patch_forward_for_ga,
# patch_flash_attention_forward,
# patch_training_step_for_ga,
# )
#
# patch_flash_attention_forward()
# # patch_forward_for_ga()
# patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError(

View File

@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=False)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
load_model(cfg, tokenizer, inference=False)
@with_temp_dir
def test_mistral_multipack(self, temp_dir):

View File

@@ -3,8 +3,6 @@ import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
@@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_self_attn_patchable(self):
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_self_attn_is_patchable(),

View File

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,6 +1,8 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest
import pytest
from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable,
check_training_step_is_patchable,
@@ -10,6 +12,7 @@ from axolotl.monkeypatch.trainer_grad_accum import (
class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests."""
@pytest.mark.skip("may not be needed for latest transformers version")
def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
@@ -17,6 +20,7 @@ class TestTrainerGAIntegration(unittest.TestCase):
"HF transformers Trainer.training_step has changed and isn't patchable",
)
@pytest.mark.skip("may not be needed for latest transformers version")
def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(