update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:
* bump datasets, tokenizers, trl
* remove log workarounds in trl
* bump lm-eval
* remove unsloth_ import from critical path
* remove llama fa2 from conftest
* unsloth breaks with latest upstream
This commit is contained in:
@@ -14,11 +14,11 @@ packaging==23.2
|
|||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.47.1
|
transformers==4.47.1
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.21.0
|
||||||
accelerate==1.2.1
|
accelerate==1.2.1
|
||||||
datasets==3.1.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.12.1
|
trl==0.13.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
@@ -53,7 +53,7 @@ zstandard==0.22.0
|
|||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
# lm eval harness
|
# lm eval harness
|
||||||
lm_eval==0.4.4
|
lm_eval==0.4.7
|
||||||
langdetect==1.0.9
|
langdetect==1.0.9
|
||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
|
|||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.2
|
axolotl-contribs-lgpl==0.0.3
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from packaging import version
|
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
@@ -984,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
return super().log(logs, start_time)
|
||||||
try:
|
|
||||||
return super().log(logs, start_time)
|
|
||||||
except TypeError:
|
|
||||||
return super().log(logs) # transformers<=4.46
|
|
||||||
return super().log(logs) # transformers<=4.46
|
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
@@ -1173,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
|
||||||
# TODO remove once trl supports the updated to the Trainer.log method
|
|
||||||
# logs either has 'loss' or 'eval_loss'
|
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
|
||||||
# Add averaged stored metrics to logs
|
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
|
||||||
del self._stored_metrics[train_eval]
|
|
||||||
|
|
||||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
|
||||||
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
|
||||||
)
|
|
||||||
# transformers<=4.46
|
|
||||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1197,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
|
||||||
# TODO remove once trl supports the updated to the Trainer.log method
|
|
||||||
# logs either has 'loss' or 'eval_loss'
|
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
|
||||||
# Add averaged stored metrics to logs
|
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
|
||||||
del self._stored_metrics[train_eval]
|
|
||||||
|
|
||||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
|
||||||
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
|
||||||
)
|
|
||||||
# transformers<=4.46
|
|
||||||
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1221,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
|
||||||
# TODO remove once trl supports the updated to the Trainer.log method
|
|
||||||
# logs either has 'loss' or 'eval_loss'
|
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
|
||||||
# train metrics should have no prefix, eval should have 'eval_'
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
|
||||||
# accumulate average metrics from sums and lengths
|
|
||||||
for split in ["chosen", "rejected"]:
|
|
||||||
if f"count/{split}" in self._stored_metrics[train_eval]:
|
|
||||||
count_sum = (
|
|
||||||
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
|
|
||||||
.sum()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
for metric in ["rewards", "logps", "logits"]:
|
|
||||||
logs[f"{prefix}{metric}/{split}"] = (
|
|
||||||
torch.Tensor(
|
|
||||||
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
|
||||||
)
|
|
||||||
.sum()
|
|
||||||
.item()
|
|
||||||
/ count_sum
|
|
||||||
)
|
|
||||||
# delete obsolete metric
|
|
||||||
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
|
||||||
del self._stored_metrics[train_eval][f"count/{split}"]
|
|
||||||
# calculate reward margin
|
|
||||||
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
|
||||||
logs[f"{prefix}rewards/margins"] = (
|
|
||||||
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
|
||||||
)
|
|
||||||
# Add averaged stored metrics to logs
|
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
|
||||||
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
|
||||||
del self._stored_metrics[train_eval]
|
|
||||||
|
|
||||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
|
||||||
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
|
||||||
)
|
|
||||||
# transformers<=4.46
|
|
||||||
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1272,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
|
||||||
# TODO remove once trl supports the updated to the Trainer.log method
|
|
||||||
# logs either has 'loss' or 'eval_loss'
|
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
|
||||||
# Add averaged stored metrics to logs
|
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
|
||||||
del self._stored_metrics[train_eval]
|
|
||||||
|
|
||||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
|
||||||
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
|
||||||
)
|
|
||||||
# transformers<=4.46
|
|
||||||
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1296,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
|
||||||
# TODO remove once trl supports the updated to the Trainer.log method
|
|
||||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
|
||||||
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
|
||||||
logs, start_time
|
|
||||||
)
|
|
||||||
# transformers<=4.46
|
|
||||||
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import logging
|
|||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import logging
|
|||||||
from transformers import LlamaForCausalLM, Trainer
|
from transformers import LlamaForCausalLM, Trainer
|
||||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
"""module for patching with unsloth optimizations"""
|
"""module for patching with unsloth optimizations"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
|
||||||
import types
|
import types
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
ORIGINAL_QKV_CODE = """
|
ORIGINAL_QKV_CODE = """
|
||||||
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
|||||||
raise ValueError("Unsupported model type")
|
raise ValueError("Unsupported model type")
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
|
||||||
try:
|
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
|
||||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
|
||||||
except AttributeError:
|
|
||||||
return code, ""
|
|
||||||
return code, spaces
|
|
||||||
|
|
||||||
|
|
||||||
self_attn_lora_patched = False # pylint: disable=invalid-name
|
self_attn_lora_patched = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Shared utils for the monkeypatches
|
Shared utils for the monkeypatches
|
||||||
"""
|
"""
|
||||||
from typing import Optional
|
import re
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
|||||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detab_code(code: str) -> Tuple[str, str]:
    """Remove the leading whitespace run of ``code`` from the start of each line.

    The whitespace run at the very beginning of ``code`` is taken as the
    prefix, and every line that starts with exactly that prefix has it
    stripped.

    Args:
        code: Source text to de-indent.

    Returns:
        A ``(code, prefix)`` tuple: the de-indented text and the whitespace
        prefix that was removed. When ``code`` does not start with
        whitespace (including the empty string) it is returned unchanged
        together with ``""``.
    """
    # ``\s`` already covers tabs, so a separate ``\t`` alternative is not needed.
    leading = re.match(r"\s+", code)
    if leading is None:
        # Nothing to strip: checked explicitly rather than relying on the
        # AttributeError raised by calling ``.group`` on ``None``.
        return code, ""

    spaces = leading.group(0)
    # The prefix is pure whitespace, which has no special meaning in a
    # non-VERBOSE pattern, so it can be embedded in the regex as-is.
    # MULTILINE makes ``^`` anchor at the start of every line.
    stripped = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
    return stripped, spaces
|
||||||
|
|||||||
@@ -120,13 +120,12 @@ def temp_dir():
|
|||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def cleanup_monkeypatches():
|
def cleanup_monkeypatches():
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
||||||
LlamaAttention,
|
LlamaAttention,
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
original_fa2_forward = LlamaFlashAttention2.forward
|
# original_fa2_forward = LlamaFlashAttention2.forward
|
||||||
original_llama_attn_forward = LlamaAttention.forward
|
original_llama_attn_forward = LlamaAttention.forward
|
||||||
original_llama_forward = LlamaForCausalLM.forward
|
original_llama_forward = LlamaForCausalLM.forward
|
||||||
original_trainer_inner_training_loop = (
|
original_trainer_inner_training_loop = (
|
||||||
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
|
|||||||
# monkey patches can happen inside the tests
|
# monkey patches can happen inside the tests
|
||||||
yield
|
yield
|
||||||
# Reset LlamaFlashAttention2 forward
|
# Reset LlamaFlashAttention2 forward
|
||||||
LlamaFlashAttention2.forward = original_fa2_forward
|
# LlamaFlashAttention2.forward = original_fa2_forward
|
||||||
LlamaAttention.forward = original_llama_attn_forward
|
LlamaAttention.forward = original_llama_attn_forward
|
||||||
LlamaForCausalLM.forward = original_llama_forward
|
LlamaForCausalLM.forward = original_llama_forward
|
||||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
|
|||||||
("transformers.models.llama",),
|
("transformers.models.llama",),
|
||||||
(
|
(
|
||||||
"transformers.models.llama.modeling_llama",
|
"transformers.models.llama.modeling_llama",
|
||||||
["LlamaFlashAttention2", "LlamaAttention"],
|
[
|
||||||
|
# "LlamaFlashAttention2",
|
||||||
|
"LlamaAttention",
|
||||||
|
],
|
||||||
),
|
),
|
||||||
("transformers.trainer",),
|
("transformers.trainer",),
|
||||||
("transformers", ["Trainer"]),
|
("transformers", ["Trainer"]),
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Unsloth integration will be broken going into latest transformers"
|
||||||
|
)
|
||||||
class TestUnslothIntegration(unittest.TestCase):
|
class TestUnslothIntegration(unittest.TestCase):
|
||||||
"""Unsloth monkeypatch integration tests."""
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true"
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Unsloth integration will be broken going into latest transformers"
|
||||||
|
)
|
||||||
class TestUnslothQLoRA:
|
class TestUnslothQLoRA:
|
||||||
"""
|
"""
|
||||||
Test class for Unsloth QLoRA Llama models
|
Test class for Unsloth QLoRA Llama models
|
||||||
|
|||||||
Reference in New Issue
Block a user