feat: upgrade transformers to v4.56.1 (#3127)
* feat: upgrade transformers to v4.56
* fix handling of CP/SP now that position_ids are default even for unpacked sequences
* feat: monkeypatch list_repo_templates
* fix: apply patch for tests only
* see if updated main works at least
* fix: update to patch release and remove monkeypatch
* remove fsdp2 eval patch

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -13,7 +13,7 @@ packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft>=0.17.0
|
||||
transformers==4.55.4
|
||||
transformers==4.56.1
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.10.0
|
||||
datasets==4.0.0
|
||||
|
||||
@@ -80,13 +80,7 @@ class PatchManager:
|
||||
patch_maybe_log_save_evaluate,
|
||||
)
|
||||
|
||||
patch_fsdp2 = (
|
||||
self.cfg.torch_compile
|
||||
and self.cfg.fsdp_config
|
||||
and self.cfg.fsdp_version == 2
|
||||
)
|
||||
|
||||
patch_evaluation_loop(patch_fsdp2)
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
|
||||
@@ -28,15 +28,6 @@ PATCHED_EVAL_CODE = {
|
||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
||||
}
|
||||
|
||||
ORIGINAL_FSDP2_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_FSDP2_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
||||
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
||||
|
||||
@@ -46,13 +37,7 @@ def check_evaluation_loop_is_patchable() -> bool:
|
||||
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
||||
|
||||
|
||||
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
|
||||
|
||||
|
||||
def patch_evaluation_loop(patch_fsdp2: bool):
|
||||
def patch_evaluation_loop():
|
||||
"""Patch the evaluation_loop method."""
|
||||
# Check if already patched
|
||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||
@@ -75,13 +60,6 @@ def patch_evaluation_loop(patch_fsdp2: bool):
|
||||
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
|
||||
)
|
||||
|
||||
# Apply FSDP2 eval guard patch if needed
|
||||
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
|
||||
)
|
||||
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
|
||||
|
||||
# Rename the function to avoid conflicts
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
"def evaluation_loop(",
|
||||
|
||||
@@ -48,10 +48,10 @@ def apply_sequence_parallelism(
|
||||
- The original sequence length before padding.
|
||||
- The number of padding tokens added.
|
||||
"""
|
||||
original_seq_len = batch["input_ids"].size(1)
|
||||
batch_size, original_seq_len = batch["input_ids"].shape
|
||||
|
||||
# Update ring attention params if needed
|
||||
if batch.get("position_ids") is not None:
|
||||
if batch.get("position_ids") is not None and batch_size == 1:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
else:
|
||||
# If position_ids aren't already in the batch, create them
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
check_evaluation_loop_is_fsdp2_patchable,
|
||||
check_evaluation_loop_is_patchable,
|
||||
check_maybe_log_save_evaluate_is_patchable,
|
||||
)
|
||||
@@ -20,7 +19,6 @@ class TestTrainerLossCalc(unittest.TestCase):
|
||||
the patched code changes upstream.
|
||||
"""
|
||||
assert check_evaluation_loop_is_patchable()
|
||||
assert check_evaluation_loop_is_fsdp2_patchable()
|
||||
assert check_maybe_log_save_evaluate_is_patchable()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user