Various fixes for CI, save_only_model for RL, prevent packing multiprocessing deadlocks (#2661)

* lean mistral ft tests, remove e2e torch 2.4.1 test

* make sure to pass save_only_model for RL

* more tests to make ci leaner, add cleanup to modal ci

* fix module for import in e2e tests

* use mp spawn to prevent deadlocks with packing

* make sure cleanup shell script is executable when cloned out
This commit is contained in:
Wing Lian
2025-05-12 10:51:18 -04:00
committed by GitHub
parent 47e0e71bc8
commit c7b6790614
13 changed files with 190 additions and 99 deletions

View File

@@ -1057,6 +1057,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

View File

@@ -6,7 +6,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from multiprocessing import cpu_count, get_context
from typing import Iterable, Union
import numba
@@ -126,6 +126,7 @@ def pack_parallel(
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
mp_start_method: str | None = "spawn",
):
"""
Pack sequences into bins using parallel processing
@@ -137,7 +138,9 @@ def pack_parallel(
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
'spawn' is often safer with Numba/PyTorch.
Set to None to use system default.
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
@@ -154,9 +157,33 @@ def pack_parallel(
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
mp_ctx = None
if mp_start_method:
try:
mp_ctx = get_context(mp_start_method)
except ValueError:
LOG.warning(
f"Failed to get multiprocessing context '{mp_start_method}'. "
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
)
mp_ctx = (
None # Fallback to default context if specified one is not available
)
if num_processes == 1:
LOG.debug("Using single process for pack_parallel, running sequentially.")
for task_args in tasks:
group_bins = _process_group(task_args)
all_bins.extend(group_bins)
else:
# Use ProcessPoolExecutor only if num_processes > 1
# Pass mp_context if available
with ProcessPoolExecutor(
max_workers=num_processes, mp_context=mp_ctx
) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
return all_bins