transformers 4.47.1 (#2187)
* transformers 4.47.1 * drop monkeypatches * can't remove patches yet * make flash attention forward ignore the loss kwargs * patch the flash attention in the modeling arch too * remove fsdp and deepspeed patches * cleanup PR * bump accelerate and torchao, also logically reorder/group requirements * meant to include torchao * use official patch release
This commit is contained in:
@@ -11,22 +11,27 @@ liger-kernel==0.4.2
|
|||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.47.0
|
transformers==4.47.1
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
accelerate==1.2.0
|
accelerate==1.2.1
|
||||||
datasets==3.1.0
|
datasets==3.1.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
|
trl==0.12.1
|
||||||
|
|
||||||
|
optimum==1.16.2
|
||||||
|
hf_transfer
|
||||||
|
sentencepiece
|
||||||
|
gradio==3.50.2
|
||||||
|
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
sentencepiece
|
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
optimum==1.16.2
|
|
||||||
hf_transfer
|
|
||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
@@ -36,7 +41,6 @@ scipy
|
|||||||
scikit-learn==1.4.2
|
scikit-learn==1.4.2
|
||||||
nvidia-ml-py==12.560.30
|
nvidia-ml-py==12.560.30
|
||||||
art
|
art
|
||||||
gradio==3.50.2
|
|
||||||
tensorboard
|
tensorboard
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
|
||||||
@@ -45,7 +49,6 @@ s3fs>=2024.5.0
|
|||||||
gcsfs>=2024.5.0
|
gcsfs>=2024.5.0
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl==0.12.1
|
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
@@ -55,5 +58,5 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.5.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|||||||
@@ -32,5 +32,5 @@ else:
|
|||||||
raise RuntimeError(f"Torch = {v} too new!")
|
raise RuntimeError(f"Torch = {v} too new!")
|
||||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||||
print(
|
print(
|
||||||
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
|
f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from transformers import LlamaForCausalLM, Trainer
|
from transformers import LlamaForCausalLM, Trainer
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
@@ -13,10 +14,7 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
|||||||
|
|
||||||
ORIGINAL_CONTEXT_CODE = """
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
with self.compute_loss_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
if self.model_accepts_loss_kwargs:
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
loss = self.compute_loss(model, inputs)
|
|
||||||
else:
|
|
||||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATCHED_CONTEXT_CODE = """
|
PATCHED_CONTEXT_CODE = """
|
||||||
@@ -288,3 +286,23 @@ def patch_training_loop_for_deepspeed_0_16_x():
|
|||||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_flash_attention_forward():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
|
||||||
|
"""
|
||||||
|
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
|
|
||||||
|
def proxy_flash_attention_forward(*args, **kwargs):
|
||||||
|
kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
return _flash_attention_forward(*args, **kwargs)
|
||||||
|
|
||||||
|
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
|
||||||
|
proxy_flash_attention_forward
|
||||||
|
)
|
||||||
|
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
|
||||||
|
proxy_flash_attention_forward
|
||||||
|
)
|
||||||
|
|||||||
@@ -380,19 +380,6 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
|
||||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
|
||||||
patch_training_loop_for_fsdp,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_training_loop_for_fsdp()
|
|
||||||
elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
|
|
||||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
|
||||||
patch_training_loop_for_deepspeed_0_16_x,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_training_loop_for_deepspeed_0_16_x()
|
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing == "unsloth":
|
if self.cfg.gradient_checkpointing == "unsloth":
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||||
|
|
||||||
@@ -401,10 +388,12 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.model_config_type == "llama":
|
if self.cfg.model_config_type == "llama":
|
||||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||||
|
patch_flash_attention_forward,
|
||||||
patch_forward_for_ga,
|
patch_forward_for_ga,
|
||||||
patch_training_step_for_ga,
|
patch_training_step_for_ga,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
patch_flash_attention_forward()
|
||||||
patch_forward_for_ga()
|
patch_forward_for_ga()
|
||||||
patch_training_step_for_ga()
|
patch_training_step_for_ga()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user