Misc fixes 20250130 (#2301)

* misc fixes for garbage collection and L40S w NCCL P2P

* patch bnb fix for triton check

* chore: lint

* change up import

* try patching differently

* remove patch for bnb fix for now

* more verbose checks and tweak train loss threshold
This commit is contained in:
Wing Lian
2025-01-31 08:58:04 -05:00
committed by GitHub
parent 6f294c3d8d
commit cf17649ef3
5 changed files with 14 additions and 5 deletions

View File

@@ -846,6 +846,12 @@ class GCCallback(TrainerCallback):
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if state.global_step % self.gc_steps == 0:
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
torch.cuda.empty_cache()
gc.collect()

View File

@@ -10,7 +10,7 @@ from accelerate.utils.environment import get_gpu_info
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
unsupported_devices = {"RTX 6000 Ada"}
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8: