Fix: Gradient Accumulation issue (#1980)

* feat: support new arg num_items_in_batch

* use kwargs to manage extra unknown kwargs for now

* upgrade against upstream transformers main

* make sure trl is on latest too

* fix for upgraded trl

* fix: handle trl and transformer signature change

* feat: update trl to handle transformer signature

* RewardDataCollatorWithPadding no longer has max_length

* handle updated signature for tokenizer vs processor class

* invert logic for tokenizer vs processor class

* processing_class, not processor class

* also handle processing class in dpo

* handle model name w model card creation

* upgrade transformers and add a loss check test

* fix install of tbparse requirements

* make sure to add tbparse to req

* feat: revert kwarg to positional kwarg to be explicit

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
NanoCode012
2024-10-25 22:28:23 +07:00
committed by GitHub
parent 1d6a5e2bd6
commit 2501c1a6a3
11 changed files with 170 additions and 98 deletions

View File

@@ -27,7 +27,7 @@ jobs:
run: | run: |
pip3 install wheel packaging pip3 install wheel packaging
pip3 install -e . pip3 install -e .
pip3 install -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name - name: Extract tag name
id: tag id: tag

View File

@@ -47,13 +47,14 @@ jobs:
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging pip3 install --upgrade packaging
pip3 install -U -e . pip3 install -U -e .
pip3 install -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests - name: Run tests
run: | run: |

View File

@@ -62,7 +62,7 @@ jobs:
run: | run: |
pip3 show torch pip3 show torch
pip3 install -U -e . pip3 install -U -e .
pip3 install -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests - name: Run tests
run: | run: |

View File

@@ -27,6 +27,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
fi fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -36,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi fi
# So we can test the Docker image # So we can test the Docker image
RUN pip install -r requirements-tests.txt RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -2,3 +2,4 @@ pre-commit
black black
mypy mypy
types-requests types-requests
tbparse

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.13.2 peft==0.13.2
transformers==4.45.2 transformers==4.46.0
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.44.1 bitsandbytes==0.44.1
accelerate==1.0.1 accelerate==1.0.1
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0 gcsfs>=2024.5.0
# adlfs # adlfs
trl==0.9.6 trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -7,6 +7,7 @@ import abc
import gc import gc
import importlib import importlib
import importlib.util import importlib.util
import inspect
import logging import logging
import math import math
import os import os
@@ -27,7 +28,6 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
PreTrainedModel,
Trainer, Trainer,
TrainerCallback, TrainerCallback,
TrainingArguments, TrainingArguments,
@@ -666,7 +666,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params) return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc
# if self.args.sample_packing: # if self.args.sample_packing:
# labels = inputs.pop("labels") # labels = inputs.pop("labels")
@@ -674,8 +676,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha: if self.args.orpo_alpha:
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs) return self.orpo_compute_loss(
return super().compute_loss(model, inputs, return_outputs=return_outputs) model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
@staticmethod @staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -771,7 +783,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2) ).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(self, model, inputs, return_outputs=False): def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs, inputs,
label_pad_token=-100, label_pad_token=-100,
@@ -898,6 +916,7 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model, model,
inputs, inputs,
return_outputs=False, # pylint: disable=unused-argument return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
): ):
input_ids = inputs.pop("input_ids") input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits lm_logits = model(input_ids).logits
@@ -1005,18 +1024,32 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
def tokenize_row( def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None self,
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict: ) -> Dict:
res = super().tokenize_row(feature, model=model) res = super().tokenize_row(
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None: features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys(): for key in res.keys():
res[key] = res[key][1:] res[key] = res[key][1:]
return res return res
def training_step( def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor: ) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs) loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
@@ -1667,12 +1700,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,
) )
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls( trainer = trainer_cls(
model=self.model, model=self.model,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset, eval_dataset=self.eval_dataset,
args=training_args, args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs), data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
**trainer_kwargs, **trainer_kwargs,
@@ -1713,6 +1751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] ]
if self.cfg.reward_model: if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
@@ -1915,7 +1955,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
@@ -1927,11 +1967,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
else: else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}") raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls( dpo_trainer = trainer_cls(
*trainer_cls_args, *trainer_cls_args,
args=training_args, args=training_args,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
**dpo_trainer_kwargs, **dpo_trainer_kwargs,
) )

View File

@@ -16,26 +16,6 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """ ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
@@ -80,12 +60,6 @@ def get_forward_code() -> str:
return forward return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str: def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward) forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward return forward
@@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
logits,
labels,
vocab_size: int, # pylint: disable=unused-argument
num_items_in_batch: int = None,
ignore_index: int = -100, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama": if model_type == "llama":
forward = get_forward_code() from transformers.loss import loss_utils
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
forward = forward.replace( loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else: else:
raise ValueError("Unsupported model type") raise ValueError("Unsupported model type")

View File

@@ -260,8 +260,10 @@ def train(
if not cfg.hub_model_id: if not cfg.hub_model_id:
try: try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) trainer.create_model_card(
except AttributeError: model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
pass pass
elif cfg.hub_model_id: elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated # defensively push to the hub to ensure the model card is updated

View File

@@ -1,22 +1,12 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest import unittest
from axolotl.monkeypatch.unsloth_ import ( from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
check_cel_is_patchable,
check_self_attn_is_patchable,
)
class TestUnslothIntegration(unittest.TestCase): class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests.""" """Unsloth monkeypatch integration tests."""
def test_is_cel_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_cel_is_patchable(),
"HF transformers loss code has changed and isn't patchable",
)
def test_is_self_attn_patchable(self): def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code # ensures the current version of transformers has loss code that matches our patching code
self.assertTrue( self.assertTrue(

View File

@@ -0,0 +1,74 @@
"""
E2E tests for packed training
"""
import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"