Update unsloth for torch.cuda.amp deprecation (#2042)

* update deprecated unsloth torch cuda amp decorator

* WIP fix torch.cuda.amp deprecation

* lint

* relax torch version requirement

* remove use of partial

* remove use of partial

* lint

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
Sunny Liu
2024-11-13 15:17:34 -05:00
committed by GitHub
parent c5eb9ea2c2
commit 342935cff3

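For context, PyTorch 2.4 moved custom_fwd/custom_bwd from torch.cuda.amp to torch.amp (the new variants take a device_type argument), and the old spellings now emit deprecation warnings. Below is a minimal, self-contained sketch of the version-gated aliasing pattern this commit adopts; the MyCheckpointFunction class is a hypothetical illustration, not part of the diff.

import torch
from packaging import version

# Choose the decorator spelling that matches the installed torch version.
if version.parse(torch.__version__) < version.parse("2.4.0"):
    # Pre-2.4: the decorators live under torch.cuda.amp and take no arguments.
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd
else:
    # 2.4+: torch.amp.custom_fwd/custom_bwd require device_type and return
    # the actual decorator when called without a function.
    custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    custom_bwd = torch.amp.custom_bwd(device_type="cuda")

class MyCheckpointFunction(torch.autograd.Function):  # hypothetical example
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # x is unused here; the gradient of 2*x is 2
        return grad_output * 2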

@@ -14,6 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version < version.parse("2.4.0"):
    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
"""
@staticmethod
@torch.cuda.amp.custom_fwd
@torch_cuda_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
        return output
    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch_cuda_amp_custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
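
Not shown in this diff is the call site. A hedged sketch of how the checkpointer is typically invoked, assuming a per-layer wrapper (the wrapper name and arguments are illustrative, not taken from the diff):

def checkpointed_layer_forward(decoder_layer, hidden_states, *layer_args):
    # .apply() dispatches to forward() above: hidden_states is offloaded to CPU
    # during the forward pass and brought back to CUDA in backward() so the
    # layer can be recomputed for its gradients.
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        decoder_layer, hidden_states, *layer_args
    )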