Misc fixes 20250130 (#2301)
* Misc fixes for garbage collection and for L40S GPUs with NCCL P2P.
* Patch bnb fix for triton check.
* chore: lint.
* Change up import; try patching differently.
* Remove patch for bnb fix for now.
* More verbose checks and tweak train-loss threshold.
This commit is contained in:
@@ -846,6 +846,12 @@ class GCCallback(TrainerCallback):
|
||||
def on_step_end(
    self, args, state, control, **kwargs  # pylint: disable=unused-argument
):
    """Periodically free CUDA cache and run the Python GC during training.

    Fires every ``self.gc_steps`` optimizer steps. Guarding with
    ``self.gc_steps > 0`` both disables the callback when the interval is
    unset (0) and avoids a ZeroDivisionError from ``% 0``.
    """
    if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
        # Release cached CUDA allocator blocks back to the driver, then
        # collect unreachable Python objects (which may hold tensor refs).
        torch.cuda.empty_cache()
        gc.collect()
def on_epoch_end(
    self, args, state, control, **kwargs  # pylint: disable=unused-argument
):
    """Free accumulated GPU allocator cache and garbage at every epoch boundary.

    Unconditional (unlike the per-step hook): an epoch end is infrequent
    enough that the cost of emptying the CUDA cache and running a full
    collection is always acceptable.
    """
    torch.cuda.empty_cache()
    gc.collect()
|
||||
@@ -10,7 +10,7 @@ from accelerate.utils.environment import get_gpu_info
|
||||
def check_cuda_p2p_ib_support():
|
||||
if not accelerate_check_cuda_p2p_ib_support():
|
||||
return False
|
||||
unsupported_devices = {"RTX 6000 Ada"}
|
||||
unsupported_devices = {"RTX 6000 Ada", "L40S"}
|
||||
try:
|
||||
device_names, device_count = get_gpu_info()
|
||||
if 1 < device_count < 8:
|
||||
|
||||
Reference in New Issue
Block a user