update trl to 0.17.0 (#2560)

* update trl to 0.17.0

* grpo + vllm no longer supported with 2.5.1 due to vllm constraints

* disable VLLM_USE_V1 for ci

* improve handling of killing off the multiprocessing vllm service

* debug why this doesn't run in CI

* increase vllm wait time

* increase timeout to 5min

* upgrade to vllm 0.8.4

* dump out the vllm log for debugging

* use debug logging

* increase vllm start timeout

* use NVL instead

* disable torch compile cache

* revert some commented checks now that grpo tests are fixed

* increase vllm timeout back to 5min
This commit is contained in:
Wing Lian
2025-04-27 19:19:53 -04:00
committed by GitHub
parent f9c7c3bb72
commit dc4da4a7e2
7 changed files with 93 additions and 30 deletions

View File

@@ -24,7 +24,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: vllm axolotl_extras:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"

View File

@@ -43,7 +43,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: vllm axolotl_extras:
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 126 - cuda: 126

View File

@@ -269,7 +269,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
num_gpus: 1 num_gpus: 1
axolotl_extras: vllm axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"

View File

@@ -20,4 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov
codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true

View File

@@ -11,13 +11,13 @@ liger-kernel==0.5.8
packaging==23.2 packaging==23.2
peft==0.15.1 peft==0.15.2
transformers==4.51.3 transformers==4.51.3
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.6.0 accelerate==1.6.0
datasets==3.5.0 datasets==3.5.0
deepspeed>=0.15.4 deepspeed>=0.15.4
trl==0.16.1 trl==0.17.0
hf_xet==1.0.0 hf_xet==1.0.0
hqq==0.2.5 hqq==0.2.5

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7): if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0 # _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.3"] extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append( _install_requires.append(
"xformers==0.0.29.post2" "xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6 ) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.3"] extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:

View File

@@ -4,11 +4,14 @@ GRPO test suite
import os import os
import random import random
import shutil
import subprocess # nosec B404 import subprocess # nosec B404
import sys import sys
import tempfile
import time import time
from pathlib import Path from pathlib import Path
import psutil
import pytest import pytest
import requests import requests
import yaml import yaml
@@ -21,8 +24,8 @@ from tests.e2e.utils import require_vllm
def start_vllm( def start_vllm(
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs model: str, env: dict, wait: int | None = None, quiet=False, **kwargs
) -> int: ) -> subprocess.Popen:
""" """
helper function to start the VLLM server in the background, mostly for testing purposes helper function to start the VLLM server in the background, mostly for testing purposes
""" """
@@ -46,10 +49,41 @@ def start_vllm(
# print out the command to be executed # print out the command to be executed
print(" ".join(cmd)) print(" ".join(cmd))
vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
temp_file.write(
"""{
"formatters": {
"json": {
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
}
},
"handlers": {
"file": {
"class": "logging.FileHandler",
"formatter": "json",
"level": "DEBUG",
"filename": "/tmp/vllm.log",
"mode": "a"
}
},
"loggers": {
"vllm": {
"handlers": ["file"],
"level": "DEBUG",
"propagate": false
}
},
"version": 1
}"""
)
cmd_env = env.copy()
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
# start `trl vllm-serve` command in the background and capture the process id # start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with process = subprocess.Popen( # pylint: disable=consider-using-with
cmd, cmd,
env=env, env=cmd_env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE, stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE, stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603 ) # nosec B603
@@ -58,32 +92,51 @@ def start_vllm(
print(f"VLLM server process started (PID: {process.pid})") print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds # wait until the http server is ready, even if it 404s, but timeout after 60 seconds
period_seconds = 5
started = False started = False
if wait and host and port: if wait and host and port:
for _ in range(int(wait)): for i in range(0, int(wait), period_seconds):
try: try:
response = requests.get(f"http://{host}:{port}", timeout=1) response = requests.get(f"http://{host}:{port}", timeout=1)
print(f"{i}: VLLM server (status: {response.status_code})")
if int(response.status_code) in [200, 404]: if int(response.status_code) in [200, 404]:
started = True started = True
break break
except requests.exceptions.RequestException: except requests.exceptions.RequestException as exc:
pass print(f"{i}: VLLM server failed to start: {str(exc)}")
# also check if the process.pid is still running # also check if the process.pid is still running
if not process.poll() is None: if not process.poll() is None:
break break
time.sleep(1) time.sleep(period_seconds)
if wait and not started: if wait and not started:
print( print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs." f"VLLM server process did not start within {wait} seconds. Please check your server logs."
) )
process.kill() recursive_kill(process)
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.") raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process id # return the process
return process.pid return process
def recursive_kill(process: subprocess.Popen):
    """
    Recursively kill a process and its children

    Descendants are signalled before the parent, so the parent cannot respawn
    or orphan workers while it is being torn down. The helper is best-effort:
    processes that have already exited are skipped instead of raising, because
    callers invoke it from ``finally`` blocks where a cleanup error would mask
    the real test failure.

    Args:
        process: handle of the root process to tear down.
    """
    try:
        root = psutil.Process(process.pid)
    except psutil.NoSuchProcess:
        # the root process is already gone, nothing left to clean up
        return

    try:
        targets = root.children(recursive=True)
    except psutil.NoSuchProcess:
        # the root exited between the lookup above and the enumeration
        targets = []
    targets.append(root)  # the parent goes last, after all of its descendants

    for target in targets:
        try:
            target.terminate()
            # psutil's kill() sends SIGKILL and guards against PID reuse, so
            # no additional raw os.kill(pid, 9) is required afterwards
            target.kill()
        except psutil.NoSuchProcess:
            # exited on its own (or due to an earlier signal) in the meantime
            continue
class TestGRPO: class TestGRPO:
@@ -174,16 +227,17 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy() current_env = os.environ.copy()
env = { env = {
"NCCL_P2P_LEVEL": "LOC", "NCCL_P2P_LEVEL": "NVL",
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0", "VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
} }
vllm_process_id = start_vllm( vllm_process = start_vllm(
cfg.base_model, cfg.base_model,
env=env, env=env,
quiet=True, quiet=True,
wait=120, wait=300,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len, max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching, enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -202,10 +256,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port", "--main-process-port",
f"{get_torch_dist_unique_port()}", f"{get_torch_dist_unique_port()}",
], ],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env}, env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
) )
finally: finally:
os.kill(vllm_process_id, 9) recursive_kill(vllm_process)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_gpus", "num_gpus",
@@ -262,16 +320,17 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy() current_env = os.environ.copy()
env = { env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable "NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0", "VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
} }
vllm_process_id = start_vllm( vllm_process = start_vllm(
cfg.base_model, cfg.base_model,
env=env, env=env,
quiet=True, quiet=True,
wait=120, wait=300,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len, max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching, enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -290,7 +349,11 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port", "--main-process-port",
f"{get_torch_dist_unique_port()}", f"{get_torch_dist_unique_port()}",
], ],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env}, env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
) )
finally: finally:
os.kill(vllm_process_id, 9) recursive_kill(vllm_process)