@@ -218,6 +218,52 @@ def test_is_colored_output_on(self):
218
218
color = trtorch .logging .get_is_colored_output_on ()
219
219
self .assertTrue (color )
220
220
221
+ class TestDevice (unittest .TestCase ):
222
+
223
+ def test_from_string_constructor (self ):
224
+ device = trtorch .Device ("cuda:0" )
225
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
226
+ self .assertEqual (device .gpu_id , 0 )
227
+
228
+ device = trtorch .Device ("gpu:1" )
229
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
230
+ self .assertEqual (device .gpu_id , 1 )
231
+
232
+ def test_from_string_constructor_dla (self ):
233
+ device = trtorch .Device ("dla:0" )
234
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
235
+ self .assertEqual (device .gpu_id , 0 )
236
+ self .assertEqual (device .dla_core , 0 )
237
+
238
+ device = trtorch .Device ("dla:1" , allow_gpu_fallback = True )
239
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
240
+ self .assertEqual (device .gpu_id , 0 )
241
+ self .assertEqual (device .dla_core , 1 )
242
+ self .assertEqual (device .allow_gpu_fallback , True )
243
+
244
+ def test_kwargs_gpu (self ):
245
+ device = trtorch .Device (gpu_id = 0 )
246
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
247
+ self .assertEqual (device .gpu_id , 0 )
248
+
249
+ def test_kwargs_dla_and_settings (self ):
250
+ device = trtorch .Device (dla_core = 1 , allow_gpu_fallback = False )
251
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
252
+ self .assertEqual (device .gpu_id , 0 )
253
+ self .assertEqual (device .dla_core , 1 )
254
+ self .assertEqual (device .allow_gpu_fallback , False )
255
+
256
+ device = trtorch .Device (gpu_id = 1 , dla_core = 0 , allow_gpu_fallback = True )
257
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
258
+ self .assertEqual (device .gpu_id , 1 )
259
+ self .assertEqual (device .dla_core , 0 )
260
+ self .assertEqual (device .allow_gpu_fallback , True )
261
+
262
+ def test_from_torch (self ):
263
+ device = trtorch .Device ._from_torch_device (torch .device ("cuda:0" ))
264
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
265
+ self .assertEqual (device .gpu_id , 0 )
266
+
221
267
222
268
def test_suite ():
223
269
suite = unittest .TestSuite ()
@@ -231,6 +277,7 @@ def test_suite():
231
277
suite .addTest (
232
278
TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
233
279
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
280
+ suite .addTest (unittest .makeSuite (TestDevice ))
234
281
235
282
return suite
236
283
0 commit comments