Fsdp grad accum monkeypatch (#2064)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.13.2
|
||||||
transformers==4.46.1
|
transformers==4.46.2
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.1.0
|
accelerate==1.1.0
|
||||||
|
|||||||
83
src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
Normal file
83
src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
fix for FSDP gradient accumulation
|
||||||
|
see https://github.com/huggingface/transformers/pull/34645
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
|
||||||
|
|
||||||
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i == len(batch_samples) - 1
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CONTEXT_CODE = """
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i != len(batch_samples) - 1
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_loop_code() -> str:
|
||||||
|
training_loop = inspect.getsource(
|
||||||
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_loop_is_patchable() -> bool:
|
||||||
|
train_loop = get_training_loop_code()
|
||||||
|
train_loop, _ = detab_code(train_loop)
|
||||||
|
return ORIGINAL_CONTEXT_CODE in train_loop
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_loop_for_fsdp_grad_accum():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for FSDP gradient accumulation
|
||||||
|
"""
|
||||||
|
|
||||||
|
train_loop = get_training_loop_code()
|
||||||
|
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
train_loop
|
||||||
|
)
|
||||||
|
train_loop, _ = detab_code(train_loop)
|
||||||
|
assert (
|
||||||
|
ORIGINAL_CONTEXT_CODE in train_loop
|
||||||
|
), "Original _inner_training_loop code not found"
|
||||||
|
|
||||||
|
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||||
|
train_loop = train_loop.replace(
|
||||||
|
"def _inner_training_loop(",
|
||||||
|
"def _fixed_inner_training_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in train_loop:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _inner_training_loop", main_process_only=True)
|
||||||
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
@@ -16,6 +16,9 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
|
||||||
|
patch_training_loop_for_fsdp_grad_accum,
|
||||||
|
)
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -493,6 +496,11 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
|
if cfg.fsdp:
|
||||||
|
try:
|
||||||
|
patch_training_loop_for_fsdp_grad_accum()
|
||||||
|
except AssertionError:
|
||||||
|
pass
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
|
|||||||
15
tests/e2e/patched/test_trainer_fsdp.py
Normal file
15
tests/e2e/patched/test_trainer_fsdp.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerFSDPIntegration(unittest.TestCase):
|
||||||
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
|
def test_train_loop_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_training_loop_is_patchable(),
|
||||||
|
"HF transformers _inner_training_loop has changed and isn't patchable",
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user