Skip to content

Commit ebb906f

Browse files
authored
Merge pull request #635 from NVIDIA/fix_device
fix(//py): Fix trtorch.Device alternate constructor options
2 parents 0a39189 + fa08311 commit ebb906f

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

py/trtorch/Device.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from trtorch import _types
4-
import logging
4+
import trtorch.logging
55
import trtorch._C
66

77
import warnings
@@ -54,23 +54,27 @@ def __init__(self, *args, **kwargs):
5454
else:
5555
self.dla_core = id
5656
self.gpu_id = 0
57-
logging.log(logging.log.Level.Warning,
58-
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
57+
trtorch.logging.log(trtorch.logging.Level.Warning,
58+
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
5959

6060
elif len(args) == 0:
61-
if not "gpu_id" in kwargs or not "dla_core" in kwargs:
61+
if "gpu_id" in kwargs or "dla_core" in kwargs:
6262
if "dla_core" in kwargs:
6363
self.device_type = _types.DeviceType.DLA
6464
self.dla_core = kwargs["dla_core"]
6565
if "gpu_id" in kwargs:
6666
self.gpu_id = kwargs["gpu_id"]
6767
else:
6868
self.gpu_id = 0
69-
logging.log(logging.log.Level.Warning,
70-
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
69+
trtorch.logging.log(trtorch.logging.Level.Warning,
70+
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
7171
else:
7272
self.gpu_id = kwargs["gpu_id"]
73-
self.device_type == _types.DeviceType.GPU
73+
self.device_type = _types.DeviceType.GPU
74+
else:
75+
raise ValueError(
76+
"Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
77+
)
7478

7579
else:
7680
raise ValueError(
@@ -80,6 +84,7 @@ def __init__(self, *args, **kwargs):
8084
if "allow_gpu_fallback" in kwargs:
8185
if not isinstance(kwargs["allow_gpu_fallback"], bool):
8286
raise TypeError("allow_gpu_fallback must be a bool")
87+
self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
8388

8489
def __str__(self) -> str:
8590
return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \

tests/py/test_api.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,53 @@ def test_is_colored_output_on(self):
219219
self.assertTrue(color)
220220

221221

222+
class TestDevice(unittest.TestCase):
223+
224+
def test_from_string_constructor(self):
225+
device = trtorch.Device("cuda:0")
226+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
227+
self.assertEqual(device.gpu_id, 0)
228+
229+
device = trtorch.Device("gpu:1")
230+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
231+
self.assertEqual(device.gpu_id, 1)
232+
233+
def test_from_string_constructor_dla(self):
234+
device = trtorch.Device("dla:0")
235+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
236+
self.assertEqual(device.gpu_id, 0)
237+
self.assertEqual(device.dla_core, 0)
238+
239+
device = trtorch.Device("dla:1", allow_gpu_fallback=True)
240+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
241+
self.assertEqual(device.gpu_id, 0)
242+
self.assertEqual(device.dla_core, 1)
243+
self.assertEqual(device.allow_gpu_fallback, True)
244+
245+
def test_kwargs_gpu(self):
246+
device = trtorch.Device(gpu_id=0)
247+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
248+
self.assertEqual(device.gpu_id, 0)
249+
250+
def test_kwargs_dla_and_settings(self):
251+
device = trtorch.Device(dla_core=1, allow_gpu_fallback=False)
252+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
253+
self.assertEqual(device.gpu_id, 0)
254+
self.assertEqual(device.dla_core, 1)
255+
self.assertEqual(device.allow_gpu_fallback, False)
256+
257+
device = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
258+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
259+
self.assertEqual(device.gpu_id, 1)
260+
self.assertEqual(device.dla_core, 0)
261+
self.assertEqual(device.allow_gpu_fallback, True)
262+
263+
def test_from_torch(self):
264+
device = trtorch.Device._from_torch_device(torch.device("cuda:0"))
265+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
266+
self.assertEqual(device.gpu_id, 0)
267+
268+
222269
def test_suite():
223270
suite = unittest.TestSuite()
224271
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
@@ -231,6 +278,7 @@ def test_suite():
231278
suite.addTest(
232279
TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
233280
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
281+
suite.addTest(unittest.makeSuite(TestDevice))
234282

235283
return suite
236284

0 commit comments

Comments
 (0)