Skip to content

Commit f51f92e

Browse files
committed
fix(//py): Fix trtorch.Device alternate constructor options
There were issues setting fields of trtorch.Device via kwargs; this patch resolves those and verifies that they work. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 0a39189 commit f51f92e

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

py/trtorch/Device.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from trtorch import _types
4-
import logging
4+
import trtorch.logging
55
import trtorch._C
66

77
import warnings
@@ -54,23 +54,27 @@ def __init__(self, *args, **kwargs):
5454
else:
5555
self.dla_core = id
5656
self.gpu_id = 0
57-
logging.log(logging.log.Level.Warning,
58-
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
57+
trtorch.logging.log(trtorch.logging.Level.Warning,
58+
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
5959

6060
elif len(args) == 0:
61-
if not "gpu_id" in kwargs or not "dla_core" in kwargs:
61+
if "gpu_id" in kwargs or "dla_core" in kwargs:
6262
if "dla_core" in kwargs:
6363
self.device_type = _types.DeviceType.DLA
6464
self.dla_core = kwargs["dla_core"]
6565
if "gpu_id" in kwargs:
6666
self.gpu_id = kwargs["gpu_id"]
6767
else:
6868
self.gpu_id = 0
69-
logging.log(logging.log.Level.Warning,
70-
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
69+
trtorch.logging.log(trtorch.logging.Level.Warning,
70+
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
7171
else:
7272
self.gpu_id = kwargs["gpu_id"]
73-
self.device_type == _types.DeviceType.GPU
73+
self.device_type = _types.DeviceType.GPU
74+
else:
75+
raise ValueError(
76+
"Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
77+
)
7478

7579
else:
7680
raise ValueError(
@@ -80,6 +84,7 @@ def __init__(self, *args, **kwargs):
8084
if "allow_gpu_fallback" in kwargs:
8185
if not isinstance(kwargs["allow_gpu_fallback"], bool):
8286
raise TypeError("allow_gpu_fallback must be a bool")
87+
self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
8388

8489
def __str__(self) -> str:
8590
return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \

tests/py/test_api.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,52 @@ def test_is_colored_output_on(self):
218218
color = trtorch.logging.get_is_colored_output_on()
219219
self.assertTrue(color)
220220

221+
class TestDevice(unittest.TestCase):
222+
223+
def test_from_string_constructor(self):
224+
device = trtorch.Device("cuda:0")
225+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
226+
self.assertEqual(device.gpu_id, 0)
227+
228+
device = trtorch.Device("gpu:1")
229+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
230+
self.assertEqual(device.gpu_id, 1)
231+
232+
def test_from_string_constructor_dla(self):
233+
device = trtorch.Device("dla:0")
234+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
235+
self.assertEqual(device.gpu_id, 0)
236+
self.assertEqual(device.dla_core, 0)
237+
238+
device = trtorch.Device("dla:1", allow_gpu_fallback=True)
239+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
240+
self.assertEqual(device.gpu_id, 0)
241+
self.assertEqual(device.dla_core, 1)
242+
self.assertEqual(device.allow_gpu_fallback, True)
243+
244+
def test_kwargs_gpu(self):
245+
device = trtorch.Device(gpu_id=0)
246+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
247+
self.assertEqual(device.gpu_id, 0)
248+
249+
def test_kwargs_dla_and_settings(self):
250+
device = trtorch.Device(dla_core=1, allow_gpu_fallback=False)
251+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
252+
self.assertEqual(device.gpu_id, 0)
253+
self.assertEqual(device.dla_core, 1)
254+
self.assertEqual(device.allow_gpu_fallback, False)
255+
256+
device = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
257+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
258+
self.assertEqual(device.gpu_id, 1)
259+
self.assertEqual(device.dla_core, 0)
260+
self.assertEqual(device.allow_gpu_fallback, True)
261+
262+
def test_from_torch(self):
263+
device = trtorch.Device._from_torch_device(torch.device("cuda:0"))
264+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
265+
self.assertEqual(device.gpu_id, 0)
266+
221267

222268
def test_suite():
223269
suite = unittest.TestSuite()
@@ -231,6 +277,7 @@ def test_suite():
231277
suite.addTest(
232278
TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
233279
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
280+
suite.addTest(unittest.makeSuite(TestDevice))
234281

235282
return suite
236283

0 commit comments

Comments
 (0)