bump transformers==4.52.4 (#2800) [skip ci]
* bump transformers==4.52.4
* don't use hf offline for qwen tokenizer
* increase timeout
* don't use methodtype
* increase timeout
* better assertion logging
* upgrade deepspeed version too
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
model patcher for chunked top-k kl-div
|
||||
"""
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union, Unpack
|
||||
|
||||
import torch
|
||||
@@ -95,4 +94,4 @@ def apply_kernel(model_type):
|
||||
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||
model_cls.forward = MethodType(kldiv_forward_llama_like, model_cls)
|
||||
model_cls.forward = kldiv_forward_llama_like
|
||||
|
||||
Reference in New Issue
Block a user