bump transformers for fsdp-grad-accum fix, remove patch (#2079)
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.13.2
|
||||||
transformers==4.46.2
|
transformers==4.46.3
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.1.0
|
accelerate==1.1.0
|
||||||
datasets==3.1.0
|
datasets==3.1.0
|
||||||
deepspeed==0.15.3
|
deepspeed==0.15.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
|
|||||||
@@ -1,83 +0,0 @@
|
|||||||
"""
|
|
||||||
fix for FSDP gradient accumulation
|
|
||||||
see https://github.com/huggingface/transformers/pull/34645
|
|
||||||
"""
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
|
|
||||||
|
|
||||||
ORIGINAL_CONTEXT_CODE = """
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i == len(batch_samples) - 1
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_CONTEXT_CODE = """
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i != len(batch_samples) - 1
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_training_loop_code() -> str:
|
|
||||||
training_loop = inspect.getsource(
|
|
||||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
return training_loop
|
|
||||||
|
|
||||||
|
|
||||||
def check_training_loop_is_patchable() -> bool:
|
|
||||||
train_loop = get_training_loop_code()
|
|
||||||
train_loop, _ = detab_code(train_loop)
|
|
||||||
return ORIGINAL_CONTEXT_CODE in train_loop
|
|
||||||
|
|
||||||
|
|
||||||
def patch_training_loop_for_fsdp_grad_accum():
|
|
||||||
"""
|
|
||||||
monkeypatch for fixing the training loop for FSDP gradient accumulation
|
|
||||||
"""
|
|
||||||
|
|
||||||
train_loop = get_training_loop_code()
|
|
||||||
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
|
||||||
train_loop
|
|
||||||
)
|
|
||||||
train_loop, _ = detab_code(train_loop)
|
|
||||||
assert (
|
|
||||||
ORIGINAL_CONTEXT_CODE in train_loop
|
|
||||||
), "Original _inner_training_loop code not found"
|
|
||||||
|
|
||||||
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
|
||||||
train_loop = train_loop.replace(
|
|
||||||
"def _inner_training_loop(",
|
|
||||||
"def _fixed_inner_training_loop(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.trainer
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.trainer):
|
|
||||||
if item in train_loop:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.trainer import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
LOG.info("patching _inner_training_loop", main_process_only=True)
|
|
||||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
|
||||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
)
|
|
||||||
@@ -16,9 +16,6 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
|
|
||||||
patch_training_loop_for_fsdp_grad_accum,
|
|
||||||
)
|
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -496,11 +493,6 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
if cfg.fsdp:
|
|
||||||
try:
|
|
||||||
patch_training_loop_for_fsdp_grad_accum()
|
|
||||||
except AssertionError:
|
|
||||||
pass
|
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
|
|
||||||
|
|
||||||
|
|
||||||
class TestTrainerFSDPIntegration(unittest.TestCase):
|
|
||||||
"""Unsloth monkeypatch integration tests."""
|
|
||||||
|
|
||||||
def test_train_loop_patchable(self):
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
|
||||||
self.assertTrue(
|
|
||||||
check_training_loop_is_patchable(),
|
|
||||||
"HF transformers _inner_training_loop has changed and isn't patchable",
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user