From 2501c1a6a3392b658fcd5d5ace3d5fb71b633afa Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 25 Oct 2024 22:28:23 +0700
Subject: [PATCH] Fix: Gradient Accumulation issue (#1980)

* feat: support new arg num_items_in_batch
* use kwargs to manage extra unknown kwargs for now
* upgrade against upstream transformers main
* make sure trl is on latest too
* fix for upgraded trl
* fix: handle trl and transformer signature change
* feat: update trl to handle transformer signature
* RewardDataCollatorWithPadding no longer has max_length
* handle updated signature for tokenizer vs processor class
* invert logic for tokenizer vs processor class
* processing_class, not processor class
* also handle processing class in dpo
* handle model name w model card creation
* upgrade transformers and add a loss check test
* fix install of tbparse requirements
* make sure to add tbparse to req
* feat: revert kwarg to positional kwarg to be explicit

---------

Co-authored-by: Wing Lian
---
 .github/workflows/pypi.yml                    |  2 +-
 .github/workflows/tests-nightly.yml           |  3 +-
 .github/workflows/tests.yml                   |  2 +-
 cicd/Dockerfile.jinja                         |  3 +-
 requirements-dev.txt                          |  1 +
 requirements.txt                              |  4 +-
 src/axolotl/core/trainer_builder.py           | 72 ++++++++++++---
 src/axolotl/monkeypatch/unsloth_.py           | 89 +++++-------------
 src/axolotl/train.py                          |  6 +-
 tests/e2e/patched/test_unsloth_integration.py | 12 +--
 tests/e2e/test_packing_loss.py                | 74 +++++++++++++++
 11 files changed, 170 insertions(+), 98 deletions(-)
 create mode 100644 tests/e2e/test_packing_loss.py
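
Background on the fix: with gradient accumulation, averaging each micro-batch's
mean cross-entropy weights every micro-batch equally, even when the number of
unmasked label tokens differs between micro-batches (as it routinely does with
sample packing). transformers 4.46 instead threads `num_items_in_batch` -- the
total unmasked token count across the accumulated step -- through the trainer so
the loss can be summed and normalized once. A minimal sketch of the difference,
assuming each micro-batch is a flattened (logits, labels) pair; the helper
names here are hypothetical, not part of this patch:

    import torch.nn.functional as F

    def buggy_step_loss(micro_batches):
        # Mean of per-micro-batch means: a 10-token micro-batch counts as
        # much as a 1000-token one, skewing the loss under accumulation.
        return sum(
            F.cross_entropy(logits, labels, ignore_index=-100)
            for logits, labels in micro_batches
        ) / len(micro_batches)

    def fixed_step_loss(micro_batches):
        # Sum per-token losses, then divide once by the global token count --
        # this is what passing `num_items_in_batch` down makes possible.
        num_items_in_batch = sum(
            (labels != -100).sum() for _, labels in micro_batches
        )
        total = sum(
            F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum")
            for logits, labels in micro_batches
        )
        return total / num_items_in_batch
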
diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 885239d18..04dbc6385 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -27,7 +27,7 @@ jobs:
         run: |
           pip3 install wheel packaging
           pip3 install -e .
-          pip3 install -r requirements-tests.txt
+          pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
       - name: Extract tag name
         id: tag
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
index 56eaae239..90b1e23cd 100644
--- a/.github/workflows/tests-nightly.yml
+++ b/.github/workflows/tests-nightly.yml
@@ -47,13 +47,14 @@ jobs:
           sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
           sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
           sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
+          sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
 
       - name: Install dependencies
         run: |
           pip3 install --upgrade pip
           pip3 install --upgrade packaging
           pip3 install -U -e .
-          pip3 install -r requirements-tests.txt
+          pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
       - name: Run tests
         run: |
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 130ac6e7b..ba50adfd3 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -62,7 +62,7 @@ jobs:
         run: |
           pip3 show torch
           pip3 install -U -e .
-          pip3 install -r requirements-tests.txt
+          pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
       - name: Run tests
         run: |
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 3b082a15b..8ce655005 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -27,6 +27,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
         sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
         sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
         sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
+        sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
     fi
 
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -36,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     fi
 
 # So we can test the Docker image
-RUN pip install -r requirements-tests.txt
+RUN pip install -r requirements-dev.txt -r requirements-tests.txt
 
 # fix so that git fetch/pull from remote works
 RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4b5df167b..dcc729d1b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,3 +2,4 @@ pre-commit
 black
 mypy
 types-requests
+tbparse
diff --git a/requirements.txt b/requirements.txt
index 067be05cf..b6e9a554e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.45.2
+transformers==4.46.0
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
 accelerate==1.0.1
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs
 
-trl==0.9.6
+trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index f05efe7b8..319ea7be5 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -7,6 +7,7 @@ import abc
 import gc
 import importlib
 import importlib.util
+import inspect
 import logging
 import math
 import os
@@ -27,7 +28,6 @@ from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
     EarlyStoppingCallback,
-    PreTrainedModel,
     Trainer,
     TrainerCallback,
     TrainingArguments,
@@ -666,7 +666,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         return DataLoader(bench_dataset, **dataloader_params)
         # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
 
-    def compute_loss(self, model, inputs, return_outputs=False):
+    def compute_loss(
+        self, model, inputs, return_outputs=False, num_items_in_batch=None
+    ):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
         #     labels = inputs.pop("labels")
@@ -674,8 +676,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         #     return (loss, outputs) if return_outputs else loss
         if self.args.orpo_alpha:
-            return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
-        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+            return self.orpo_compute_loss(
+                model,
+                inputs,
+                return_outputs=return_outputs,
+                num_items_in_batch=num_items_in_batch,
+            )
+        return super().compute_loss(
+            model,
+            inputs,
+            return_outputs=return_outputs,
+            num_items_in_batch=num_items_in_batch,
+        )
 
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -771,7 +783,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         ).squeeze(2)
         return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
 
-    def orpo_compute_loss(self, model, inputs, return_outputs=False):
+    def orpo_compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,
+        num_items_in_batch=None,  # pylint: disable=unused-argument
+    ):
         concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
             inputs,
             label_pad_token=-100,
@@ -898,6 +916,7 @@ class AxolotlMambaTrainer(AxolotlTrainer):
         model,
         inputs,
         return_outputs=False,  # pylint: disable=unused-argument
+        num_items_in_batch=None,  # pylint: disable=unused-argument
     ):
         input_ids = inputs.pop("input_ids")
         lm_logits = model(input_ids).logits
@@ -1005,18 +1024,32 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         return super().push_to_hub(*args, **kwargs)
 
     def tokenize_row(
-        self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
+        self,
+        features,
+        processing_class,
+        max_prompt_length,
+        max_completion_length,
+        add_special_tokens,
     ) -> Dict:
-        res = super().tokenize_row(feature, model=model)
-        if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
+        res = super().tokenize_row(
+            features,
+            processing_class,
+            max_prompt_length,
+            max_completion_length,
+            add_special_tokens,
+        )
+        if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
             for key in res.keys():
                 res[key] = res[key][1:]
         return res
 
     def training_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch=None,
     ) -> torch.Tensor:
-        loss: torch.Tensor = super().training_step(model, inputs)
+        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
         gc.collect()
         torch.cuda.empty_cache()
         return loss
@@ -1667,12 +1700,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 return_tensors="pt",
                 **data_collator_kwargs,
             )
+        sig = inspect.signature(trainer_cls)
+        if "processing_class" in sig.parameters.keys():
+            trainer_kwargs["processing_class"] = self.tokenizer
+        else:
+            trainer_kwargs["tokenizer"] = self.tokenizer
+
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            tokenizer=self.tokenizer,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
             callbacks=self.get_callbacks(),
             **trainer_kwargs,
@@ -1713,6 +1751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ]
         if self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
+            if "max_length" in kwargs:
+                kwargs.pop("max_length")
         elif use_batch_sampler_collator:
             if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
@@ -1915,7 +1955,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
             dpo_trainer_kwargs["max_target_length"] = None
             dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["generate_during_eval"] = True
+            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
@@ -1927,11 +1967,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+
+        sig = inspect.signature(trainer_cls)
+        if "processing_class" in sig.parameters.keys():
+            dpo_trainer_kwargs["processing_class"] = self.tokenizer
+        else:
+            dpo_trainer_kwargs["tokenizer"] = self.tokenizer
+
         dpo_trainer = trainer_cls(
             *trainer_cls_args,
             args=training_args,
             train_dataset=self.train_dataset,
-            tokenizer=self.tokenizer,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
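
The `inspect.signature` checks above exist because newer transformers/trl
releases renamed the trainers' `tokenizer` argument to `processing_class`.
A minimal standalone sketch of the pattern, with a hypothetical helper name,
assuming the trainer class accepts one keyword or the other:

    import inspect

    def tokenizer_kwarg_for(trainer_cls, tokenizer):
        # Prefer the new keyword when the installed version supports it,
        # and fall back to the deprecated one on older releases.
        if "processing_class" in inspect.signature(trainer_cls).parameters:
            return {"processing_class": tokenizer}
        return {"tokenizer": tokenizer}
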
diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py
index 3d42ad17f..c8272ac73 100644
--- a/src/axolotl/monkeypatch/unsloth_.py
+++ b/src/axolotl/monkeypatch/unsloth_.py
@@ -16,26 +16,6 @@ from transformers.models.llama.modeling_llama import (
 
 LOG = get_logger("axolotl.monkeypatch.unsloth")
 
-ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss()
-    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-    shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    loss = loss_fct(shift_logits, shift_labels)
-"""
-
-PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss = fast_cross_entropy_loss(
-        logits = shift_logits,
-        labels = shift_labels,
-    )
-"""
-
 ORIGINAL_QKV_CODE = """
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
@@ -80,12 +60,6 @@ def get_forward_code() -> str:
     return forward
 
 
-def check_cel_is_patchable() -> bool:
-    forward = get_forward_code()
-    forward, _ = detab_code(forward)
-    return ORIGINAL_CEL_CODE in forward
-
-
 def get_self_attn_code() -> str:
     forward = inspect.getsource(LlamaFlashAttention2.forward)
     return forward
@@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
 
 
 def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
+    from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
+
+    def UnslothForCausalLMLoss(  # pylint: disable=invalid-name
+        logits,
+        labels,
+        vocab_size: int,  # pylint: disable=unused-argument
+        num_items_in_batch: int = None,
+        ignore_index: int = -100,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        loss = fast_cross_entropy_loss(
+            logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
+        )
+        return loss
+
     if model_type == "llama":
-        forward = get_forward_code()
-        LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
-        forward, _ = detab_code(forward)
-        assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+        from transformers.loss import loss_utils
 
-        forward = forward.replace(
-            "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
-        )
-        forward = forward.replace(
-            "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
-            "",
-        )
-        forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
-        forward = forward.replace(
-            "def forward(",
-            "def fast_cross_entropy_loss_forward(",
-            1,
-        )
-
-        # load imports necessary
-        import transformers.models.llama.modeling_llama
-
-        items_to_import = []
-        for item in dir(transformers.models.llama.modeling_llama):
-            if item in forward:
-                items_to_import.append(item)
-
-        exec(  # pylint: disable=exec-used  # nosec B102
-            "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
-            globals(),
-        )
-
-        exec(  # pylint: disable=exec-used  # nosec B102
-            "from transformers.models.llama.modeling_llama import ("
-            + ", ".join(x for x in items_to_import)
-            + ")",
-            globals(),
-        )
-        exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
-        LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
-        LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
+        loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss  # type: ignore[assignment]
     else:
         raise ValueError("Unsupported model type")
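
The rewritten patch no longer rewrites LlamaForCausalLM.forward source at
runtime; it swaps out the shared `ForCausalLMLoss` helper that transformers
4.46 routes causal-LM loss through. For reference, a plain-PyTorch sketch of
what the patched function computes (illustrative only; unsloth's fused
`fast_cross_entropy_loss` kernel is expected to be numerically equivalent):

    import torch.nn.functional as F

    def reference_causal_lm_loss(logits, labels, num_items_in_batch=None):
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            # Sum now, normalize by the global token count when provided, so
            # accumulation steps are weighted by tokens rather than batches.
            reduction="sum" if num_items_in_batch is not None else "mean",
        )
        if num_items_in_batch is not None:
            loss = loss / num_items_in_batch
        return loss
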
"vicgalle/alpaca-gpt4", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "use_tensorboard": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) + reader = SummaryReader(event_file) + df = reader.scalars # pylint: disable=invalid-name + df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name + assert df.value.values[-1] < 2.0, "Loss is too high"