Update unsloth for torch.cuda.amp deprecation (#2042)
* update deprecated unsloth tirch cuda amp decorator * WIP fix torch.cuda.amp deprecation * lint * laxing torch version requirement * remove use of partial * remove use of partial * lint --------- Co-authored-by: sunny <sunnyliu19981005@gmail.com>
This commit is contained in:
@@ -14,6 +14,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
torch_version = version.parse(torch.__version__)
|
||||
|
||||
if torch_version < version.parse("2.4.0"):
|
||||
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
else:
|
||||
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
|
||||
|
||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
@torch_cuda_amp_custom_fwd
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
@torch_cuda_amp_custom_bwd
|
||||
def backward(ctx, dY):
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
|
||||
Reference in New Issue
Block a user