
fix(//py): Fix trtorch.Device alternate constructor options #638


Merged

1 commit merged on Sep 24, 2021
19 changes: 12 additions & 7 deletions py/trtorch/Device.py
@@ -1,7 +1,7 @@
 import torch
 
 from trtorch import _types
-import logging
+import trtorch.logging
 import trtorch._C
 
 import warnings
@@ -54,23 +54,27 @@ def __init__(self, *args, **kwargs):
                 else:
                     self.dla_core = id
                     self.gpu_id = 0
-                    logging.log(logging.log.Level.Warning,
-                                "Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
+                    trtorch.logging.log(trtorch.logging.Level.Warning,
+                                        "Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
 
         elif len(args) == 0:
-            if not "gpu_id" in kwargs or not "dla_core" in kwargs:
+            if "gpu_id" in kwargs or "dla_core" in kwargs:
                 if "dla_core" in kwargs:
                     self.device_type = _types.DeviceType.DLA
                     self.dla_core = kwargs["dla_core"]
                     if "gpu_id" in kwargs:
                         self.gpu_id = kwargs["gpu_id"]
                     else:
                         self.gpu_id = 0
-                        logging.log(logging.log.Level.Warning,
-                                    "Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
+                        trtorch.logging.log(trtorch.logging.Level.Warning,
+                                            "Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
                 else:
                     self.gpu_id = kwargs["gpu_id"]
-                    self.device_type == _types.DeviceType.GPU
+                    self.device_type = _types.DeviceType.GPU
+            else:
+                raise ValueError(
+                    "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
+                )
 
         else:
             raise ValueError(
@@ -80,6 +84,7 @@ def __init__(self, *args, **kwargs):
         if "allow_gpu_fallback" in kwargs:
             if not isinstance(kwargs["allow_gpu_fallback"], bool):
                 raise TypeError("allow_gpu_fallback must be a bool")
+            self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
 
     def __str__(self) -> str:
         return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \
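For context, this patch fixes four bugs visible in the diff above: the stdlib logging module was being called through a nonexistent logging.log.Level attribute instead of trtorch.logging, the kwargs guard was inverted so valid keyword-only calls were rejected, self.device_type == _types.DeviceType.GPU compared instead of assigned, and allow_gpu_fallback was validated but never stored. A minimal usage sketch of the constructor forms the fix enables, mirroring the new tests below (the top-level trtorch.DeviceType re-export is assumed, as the tests use it):

    import trtorch

    # String form "(gpu|dla|cuda):<id>"; a DLA device pins gpu_id to 0 on Xavier.
    dla_dev = trtorch.Device("dla:1", allow_gpu_fallback=True)

    # Keyword form repaired by this PR: gpu_id and/or dla_core may be given alone.
    gpu_dev = trtorch.Device(gpu_id=0)
    combo_dev = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)

    # allow_gpu_fallback is now actually stored on the object.
    assert combo_dev.allow_gpu_fallback is True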
48 changes: 48 additions & 0 deletions tests/py/test_api.py
@@ -230,6 +230,53 @@ def test_is_colored_output_on(self):
         self.assertTrue(color)
 
 
+class TestDevice(unittest.TestCase):
+
+    def test_from_string_constructor(self):
+        device = trtorch.Device("cuda:0")
+        self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
+        self.assertEqual(device.gpu_id, 0)
+
+        device = trtorch.Device("gpu:1")
+        self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
+        self.assertEqual(device.gpu_id, 1)
+
+    def test_from_string_constructor_dla(self):
+        device = trtorch.Device("dla:0")
+        self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
+        self.assertEqual(device.gpu_id, 0)
+        self.assertEqual(device.dla_core, 0)
+
+        device = trtorch.Device("dla:1", allow_gpu_fallback=True)
+        self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
+        self.assertEqual(device.gpu_id, 0)
+        self.assertEqual(device.dla_core, 1)
+        self.assertEqual(device.allow_gpu_fallback, True)
+
+    def test_kwargs_gpu(self):
+        device = trtorch.Device(gpu_id=0)
+        self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
+        self.assertEqual(device.gpu_id, 0)
+
+    def test_kwargs_dla_and_settings(self):
+        device = trtorch.Device(dla_core=1, allow_gpu_fallback=False)
+        self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
+        self.assertEqual(device.gpu_id, 0)
+        self.assertEqual(device.dla_core, 1)
+        self.assertEqual(device.allow_gpu_fallback, False)
+
+        device = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
+        self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
+        self.assertEqual(device.gpu_id, 1)
+        self.assertEqual(device.dla_core, 0)
+        self.assertEqual(device.allow_gpu_fallback, True)
+
+    def test_from_torch(self):
+        device = trtorch.Device._from_torch_device(torch.device("cuda:0"))
+        self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
+        self.assertEqual(device.gpu_id, 0)
+
+
 def test_suite():
     suite = unittest.TestSuite()
     suite.addTest(unittest.makeSuite(TestLoggingAPIs))
@@ -242,6 +289,7 @@ def test_suite():
     suite.addTest(
         TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
+    suite.addTest(unittest.makeSuite(TestDevice))
 
     return suite

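To exercise just the new cases locally, a minimal sketch using unittest (the import path and working directory are assumptions, not part of the PR; a CUDA-capable trtorch build is required):

    # Sketch: run only the new TestDevice cases, assuming the current
    # working directory is tests/py so test_api is importable.
    import unittest

    from test_api import TestDevice

    if __name__ == "__main__":
        unittest.TextTestRunner(verbosity=2).run(unittest.makeSuite(TestDevice))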