update trl to 0.17.0 (#2560)

* update trl to 0.17.0

* grpo + vllm no longer supported with pytorch 2.5.1 due to vllm constraints

* disable VLLM_USE_V1 for ci

* improve handling of killing off the multiprocessing vllm service

* debug why this doesn't run in CI

* increase vllm wait time

* increase timeout to 5min

* upgrade to vllm 0.8.4

* dump out the vllm log for debugging

* use debug logging

* increase vllm start timeout

* use NCCL_P2P_LEVEL=NVL instead

* disable torch compile cache

* revert some commented checks now that grpo tests are fixed

* increase vllm timeout back to 5min
Wing Lian
2025-04-27 19:19:53 -04:00
committed by GitHub
parent f9c7c3bb72
commit dc4da4a7e2
7 changed files with 93 additions and 30 deletions

View File

@@ -24,7 +24,7 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
-axolotl_extras: vllm
+axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -43,7 +43,7 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
-axolotl_extras: vllm
+axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126

View File

@@ -269,7 +269,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
-axolotl_extras: vllm
+axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -20,4 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
-codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
+codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true

View File

@@ -11,13 +11,13 @@ liger-kernel==0.5.8
packaging==23.2
-peft==0.15.1
+peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
deepspeed>=0.15.4
-trl==0.16.1
+trl==0.17.0
hf_xet==1.0.0
hqq==0.2.5

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
-extras_require_map["vllm"] = ["vllm==0.8.3"]
+extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
-extras_require_map["vllm"] = ["vllm==0.8.3"]
+extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

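As a rough illustration of the pin selection above (a sketch, not the project's setup.py; it assumes torch is importable and ignores the torch 2.5.x branch, which is truncated in this hunk), the mapping from the installed torch version to the vllm pin looks like:

    # Sketch: pick the vllm pin for the "vllm" extra from the installed torch version.
    import torch

    def vllm_pin() -> str | None:
        major, minor = (int(part) for part in torch.__version__.split(".")[:2])
        if (major, minor) >= (2, 6):
            # torch 2.6 and 2.7 both move to vllm 0.8.4 in this change
            return "vllm==0.8.4"
        # torch 2.5.x is handled outside this hunk; the CI matrices above drop the vllm extra for it
        return None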
View File

@@ -4,11 +4,14 @@ GRPO test suite
import os
import random
import shutil
import subprocess # nosec B404
import sys
import tempfile
import time
from pathlib import Path
import psutil
import pytest
import requests
import yaml
@@ -21,8 +24,8 @@ from tests.e2e.utils import require_vllm
def start_vllm(
-model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
-) -> int:
+model: str, env: dict, wait: int | None = None, quiet=False, **kwargs
+) -> subprocess.Popen:
"""
helper function to start the VLLM server in the background, mostly for testing purposes
"""
@@ -46,10 +49,41 @@ def start_vllm(
# print out the command to be executed
print(" ".join(cmd))
vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
temp_file.write(
"""{
"formatters": {
"json": {
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
}
},
"handlers": {
"file": {
"class": "logging.FileHandler",
"formatter": "json",
"level": "DEBUG",
"filename": "/tmp/vllm.log",
"mode": "a"
}
},
"loggers": {
"vllm": {
"handlers": ["file"],
"level": "DEBUG",
"propagate": false
}
},
"version": 1
}"""
)
cmd_env = env.copy()
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
# start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
-env=env,
+env=cmd_env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603
@@ -58,32 +92,51 @@ def start_vllm(
print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
period_seconds = 5
started = False
if wait and host and port:
-for _ in range(int(wait)):
+for i in range(0, int(wait), period_seconds):
try:
response = requests.get(f"http://{host}:{port}", timeout=1)
print(f"{i}: VLLM server (status: {response.status_code})")
if int(response.status_code) in [200, 404]:
started = True
break
-except requests.exceptions.RequestException:
-pass
+except requests.exceptions.RequestException as exc:
+print(f"{i}: VLLM server failed to start: {str(exc)}")
# also check if the process.pid is still running
if not process.poll() is None:
break
-time.sleep(1)
+time.sleep(period_seconds)
if wait and not started:
print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
)
-process.kill()
+recursive_kill(process)
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
-# return the process id
-return process.pid
+# return the process
+return process
def recursive_kill(process: subprocess.Popen):
"""
Recursively kill a process and its children
"""
process = psutil.Process(process.pid)
for child in psutil.Process(process.pid).children(recursive=True):
child.terminate()
child.kill()
os.kill(child.pid, 9)
process.terminate()
process.kill()
os.kill(process.pid, 9)
class TestGRPO:
@@ -174,16 +227,17 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
"NCCL_P2P_LEVEL": "NVL",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
-vllm_process_id = start_vllm(
+vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
-wait=120,
+wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -202,10 +256,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
-os.kill(vllm_process_id, 9)
+recursive_kill(vllm_process)
@pytest.mark.parametrize(
"num_gpus",
@@ -262,16 +320,17 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
-vllm_process_id = start_vllm(
+vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
-wait=120,
+wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -290,7 +349,11 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
-os.kill(vllm_process_id, 9)
+recursive_kill(vllm_process)
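
For reference, a minimal standalone sketch of the startup-polling and process-tree teardown pattern the updated test helper uses (assuming psutil and requests are installed; wait_for_http and kill_process_tree are illustrative names rather than the suite's recursive_kill, and the sketch deliberately omits the process.poll() check and the log dump on failure):

    # Sketch only: poll an HTTP endpoint until the server answers, then tear down
    # the whole process tree, children first, so no vllm workers are orphaned.
    import subprocess
    import time

    import psutil
    import requests

    def wait_for_http(host: str, port: int, wait: int, period_seconds: int = 5) -> bool:
        # Any HTTP response, even a 404, means the server is listening.
        for _ in range(0, wait, period_seconds):
            try:
                response = requests.get(f"http://{host}:{port}", timeout=1)
                if response.status_code in (200, 404):
                    return True
            except requests.exceptions.RequestException:
                pass
            time.sleep(period_seconds)
        return False

    def kill_process_tree(process: subprocess.Popen) -> None:
        # Kill children before the parent so nothing is left orphaned.
        parent = psutil.Process(process.pid)
        for child in parent.children(recursive=True):
            child.kill()
        parent.kill()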