@@ -219,6 +219,53 @@ def test_is_colored_output_on(self):
219
219
self .assertTrue (color )
220
220
221
221
222
+ class TestDevice (unittest .TestCase ):
223
+
224
+ def test_from_string_constructor (self ):
225
+ device = trtorch .Device ("cuda:0" )
226
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
227
+ self .assertEqual (device .gpu_id , 0 )
228
+
229
+ device = trtorch .Device ("gpu:1" )
230
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
231
+ self .assertEqual (device .gpu_id , 1 )
232
+
233
+ def test_from_string_constructor_dla (self ):
234
+ device = trtorch .Device ("dla:0" )
235
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
236
+ self .assertEqual (device .gpu_id , 0 )
237
+ self .assertEqual (device .dla_core , 0 )
238
+
239
+ device = trtorch .Device ("dla:1" , allow_gpu_fallback = True )
240
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
241
+ self .assertEqual (device .gpu_id , 0 )
242
+ self .assertEqual (device .dla_core , 1 )
243
+ self .assertEqual (device .allow_gpu_fallback , True )
244
+
245
+ def test_kwargs_gpu (self ):
246
+ device = trtorch .Device (gpu_id = 0 )
247
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
248
+ self .assertEqual (device .gpu_id , 0 )
249
+
250
+ def test_kwargs_dla_and_settings (self ):
251
+ device = trtorch .Device (dla_core = 1 , allow_gpu_fallback = False )
252
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
253
+ self .assertEqual (device .gpu_id , 0 )
254
+ self .assertEqual (device .dla_core , 1 )
255
+ self .assertEqual (device .allow_gpu_fallback , False )
256
+
257
+ device = trtorch .Device (gpu_id = 1 , dla_core = 0 , allow_gpu_fallback = True )
258
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
259
+ self .assertEqual (device .gpu_id , 1 )
260
+ self .assertEqual (device .dla_core , 0 )
261
+ self .assertEqual (device .allow_gpu_fallback , True )
262
+
263
+ def test_from_torch (self ):
264
+ device = trtorch .Device ._from_torch_device (torch .device ("cuda:0" ))
265
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
266
+ self .assertEqual (device .gpu_id , 0 )
267
+
268
+
222
269
def test_suite ():
223
270
suite = unittest .TestSuite ()
224
271
suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
@@ -231,6 +278,7 @@ def test_suite():
231
278
suite .addTest (
232
279
TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
233
280
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
281
+ suite .addTest (unittest .makeSuite (TestDevice ))
234
282
235
283
return suite
236
284
0 commit comments