Update unsloth for torch.cuda.amp deprecation (#2042)
* update deprecated unsloth tirch cuda amp decorator * WIP fix torch.cuda.amp deprecation * lint * laxing torch version requirement * remove use of partial * remove use of partial * lint --------- Co-authored-by: sunny <sunnyliu19981005@gmail.com>
This commit is contained in:
@@ -14,6 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
|
||||||
|
if torch_version < version.parse("2.4.0"):
|
||||||
|
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||||
|
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||||
|
else:
|
||||||
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.cuda.amp.custom_fwd
|
@torch_cuda_amp_custom_fwd
|
||||||
def forward(ctx, forward_function, hidden_states, *args):
|
def forward(ctx, forward_function, hidden_states, *args):
|
||||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.cuda.amp.custom_bwd
|
@torch_cuda_amp_custom_bwd
|
||||||
def backward(ctx, dY):
|
def backward(ctx, dY):
|
||||||
(hidden_states,) = ctx.saved_tensors
|
(hidden_states,) = ctx.saved_tensors
|
||||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||||
|
|||||||
Reference in New Issue
Block a user