Skip to content

Commit 5e666f7

Browse files
authored
[Bugfix][Ray] Set the cuda context eagerly in the ray worker (#19583)
1 parent e3a3e4d commit 5e666f7

File tree

4 files changed

+107
-0
lines changed

4 files changed

+107
-0
lines changed

‎.buildkite/test-pipeline.yaml‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,15 @@ steps:
271271
commands:
272272
- pytest -v -s prefix_caching
273273

274+
275+
- label: Platform Tests (CUDA)
276+
mirror_hardwares: [amdexperimental]
277+
source_file_dependencies:
278+
- vllm/
279+
- tests/cuda
280+
commands:
281+
- pytest -v -s cuda/test_cuda_context.py
282+
274283
- label: Samplers Test # 36min
275284
mirror_hardwares: [amdexperimental]
276285
source_file_dependencies:

‎tests/cuda/test_cuda_context.py‎

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import ctypes
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
import pytest
8+
import torch
9+
10+
from vllm.platforms import current_platform
11+
12+
13+
def check_cuda_context():
14+
"""Check CUDA driver context status"""
15+
try:
16+
cuda = ctypes.CDLL('libcuda.so')
17+
device = ctypes.c_int()
18+
result = cuda.cuCtxGetDevice(ctypes.byref(device))
19+
return (True, device.value) if result == 0 else (False, None)
20+
except Exception:
21+
return False, None
22+
23+
24+
def run_cuda_test_in_thread(device_input, expected_device_id):
25+
"""Run CUDA context test in separate thread for isolation"""
26+
try:
27+
# New thread should have no CUDA context initially
28+
valid_before, device_before = check_cuda_context()
29+
if valid_before:
30+
return False, \
31+
"CUDA context should not exist in new thread, " \
32+
f"got device {device_before}"
33+
34+
# Test setting CUDA context
35+
current_platform.set_device(device_input)
36+
37+
# Verify context is created correctly
38+
valid_after, device_id = check_cuda_context()
39+
if not valid_after:
40+
return False, "CUDA context should be valid after set_cuda_context"
41+
if device_id != expected_device_id:
42+
return False, \
43+
f"Expected device {expected_device_id}, got {device_id}"
44+
45+
return True, "Success"
46+
except Exception as e:
47+
return False, f"Exception in thread: {str(e)}"
48+
49+
50+
class TestSetCudaContext:
51+
"""Test suite for the set_cuda_context function."""
52+
53+
@pytest.mark.skipif(not current_platform.is_cuda(),
54+
reason="CUDA not available")
55+
@pytest.mark.parametrize(argnames="device_input,expected_device_id",
56+
argvalues=[
57+
(0, 0),
58+
(torch.device('cuda:0'), 0),
59+
('cuda:0', 0),
60+
],
61+
ids=["int", "torch_device", "string"])
62+
def test_set_cuda_context_parametrized(self, device_input,
63+
expected_device_id):
64+
"""Test setting CUDA context in isolated threads."""
65+
with ThreadPoolExecutor(max_workers=1) as executor:
66+
future = executor.submit(run_cuda_test_in_thread, device_input,
67+
expected_device_id)
68+
success, message = future.result(timeout=30)
69+
assert success, message
70+
71+
@pytest.mark.skipif(not current_platform.is_cuda(),
72+
reason="CUDA not available")
73+
def test_set_cuda_context_invalid_device_type(self):
74+
"""Test error handling for invalid device type."""
75+
with pytest.raises(ValueError, match="Expected a cuda device"):
76+
current_platform.set_device(torch.device('cpu'))
77+
78+
79+
if __name__ == "__main__":
80+
pytest.main([__file__, "-v"])

‎vllm/platforms/cuda.py‎

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@ def supported_dtypes(self) -> list[torch.dtype]:
7171
# though vLLM doesn't support these GPUs.
7272
return [torch.float32]
7373

74+
@classmethod
75+
def set_device(cls, device: torch.device) -> None:
76+
"""
77+
Set the device for the current platform.
78+
"""
79+
super().set_device(device)
80+
# With this trick we can force the device to be set eagerly
81+
# see https://github.com/pytorch/pytorch/issues/155668
82+
# for why and when it is needed
83+
_ = torch.zeros(1, device=device)
84+
7485
@classmethod
7586
def get_device_capability(cls,
7687
device_id: int = 0

‎vllm/platforms/interface.py‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,13 @@ def seed_everything(cls, seed: Optional[int] = None) -> None:
298298
np.random.seed(seed)
299299
torch.manual_seed(seed)
300300

301+
@classmethod
302+
def set_device(cls, device: torch.device) -> None:
303+
"""
304+
Set the device for the current platform.
305+
"""
306+
torch.cuda.set_device(device)
307+
301308
@classmethod
302309
def pre_register_and_update(cls,
303310
parser: Optional[FlexibleArgumentParser] = None

0 commit comments

Comments
 (0)