import unittest
import torch_tensorrt as torchtrt
import torch
import torchvision.models as models
import copy
from typing import Dict
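
# These tests exercise the Python-side convenience classes torchtrt.Device and
# torchtrt.Input. They only check how constructor arguments are parsed into the
# objects' fields; no TensorRT engines are built here.
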
class TestDevice(unittest.TestCase):

    def test_from_string_constructor(self):
        device = torchtrt.Device("cuda:0")
        self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
        self.assertEqual(device.gpu_id, 0)

        device = torchtrt.Device("gpu:1")
        self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
        self.assertEqual(device.gpu_id, 1)

    def test_from_string_constructor_dla(self):
        device = torchtrt.Device("dla:0")
        self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
        self.assertEqual(device.gpu_id, 0)
        self.assertEqual(device.dla_core, 0)

        device = torchtrt.Device("dla:1", allow_gpu_fallback=True)
        self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
        self.assertEqual(device.gpu_id, 0)
        self.assertEqual(device.dla_core, 1)
        self.assertEqual(device.allow_gpu_fallback, True)

    def test_kwargs_gpu(self):
        device = torchtrt.Device(gpu_id=0)
        self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
        self.assertEqual(device.gpu_id, 0)

    def test_kwargs_dla_and_settings(self):
        device = torchtrt.Device(dla_core=1, allow_gpu_fallback=False)
        self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
        self.assertEqual(device.gpu_id, 0)
        self.assertEqual(device.dla_core, 1)
        self.assertEqual(device.allow_gpu_fallback, False)

        device = torchtrt.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
        self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
        self.assertEqual(device.gpu_id, 1)
        self.assertEqual(device.dla_core, 0)
        self.assertEqual(device.allow_gpu_fallback, True)

    def test_from_torch(self):
        device = torchtrt.Device._from_torch_device(torch.device("cuda:0"))
        self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
        self.assertEqual(device.gpu_id, 0)

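
# TestInput checks that torchtrt.Input turns its various constructor forms into
# the expected internal representation. _verify_correctness converts an Input to
# that representation via _to_internal() and compares each field against a
# target dict, printing any mismatching field to ease debugging.
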
class TestInput(unittest.TestCase):

    def _verify_correctness(self, struct: torchtrt.Input, target: Dict) -> bool:
        internal = struct._to_internal()

        list_eq = lambda al, bl: all([a == b for (a, b) in zip(al, bl)])

        eq = lambda a, b: a == b

        def field_is_correct(field, equal_fn, a1, a2):
            equal = equal_fn(a1, a2)
            if not equal:
                print("\nField {} is incorrect: {} != {}".format(field, a1, a2))
            return equal

        min_ = field_is_correct("min", list_eq, internal.min, target["min"])
        opt_ = field_is_correct("opt", list_eq, internal.opt, target["opt"])
        max_ = field_is_correct("max", list_eq, internal.max, target["max"])
        is_dynamic_ = field_is_correct("is_dynamic", eq, internal.input_is_dynamic, target["input_is_dynamic"])
        explicit_set_dtype_ = field_is_correct("explicit_dtype", eq, internal._explicit_set_dtype,
                                               target["explicit_set_dtype"])
        dtype_ = field_is_correct("dtype", eq, int(internal.dtype), int(target["dtype"]))
        format_ = field_is_correct("format", eq, int(internal.format), int(target["format"]))

        return all([min_, opt_, max_, is_dynamic_, explicit_set_dtype_, dtype_, format_])

    def test_infer_from_example_tensor(self):
        shape = [1, 3, 255, 255]
        target = {
            "min": shape,
            "opt": shape,
            "max": shape,
            "input_is_dynamic": False,
            "dtype": torchtrt.dtype.half,
            "format": torchtrt.TensorFormat.contiguous,
            "explicit_set_dtype": True
        }

        example_tensor = torch.randn(shape).half()
        i = torchtrt.Input._from_tensor(example_tensor)
        self.assertTrue(self._verify_correctness(i, target))

    def test_static_shape(self):
        shape = [1, 3, 255, 255]
        target = {
            "min": shape,
            "opt": shape,
            "max": shape,
            "input_is_dynamic": False,
            "dtype": torchtrt.dtype.unknown,
            "format": torchtrt.TensorFormat.contiguous,
            "explicit_set_dtype": False
        }

        i = torchtrt.Input(shape)
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(tuple(shape))
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(torch.randn(shape).shape)
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(shape=shape)
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(shape=tuple(shape))
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(shape=torch.randn(shape).shape)
        self.assertTrue(self._verify_correctness(i, target))

    def test_data_type(self):
        shape = [1, 3, 255, 255]
        target = {
            "min": shape,
            "opt": shape,
            "max": shape,
            "input_is_dynamic": False,
            "dtype": torchtrt.dtype.half,
            "format": torchtrt.TensorFormat.contiguous,
            "explicit_set_dtype": True
        }

        i = torchtrt.Input(shape, dtype=torchtrt.dtype.half)
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(shape, dtype=torch.half)
        self.assertTrue(self._verify_correctness(i, target))

    def test_tensor_format(self):
        shape = [1, 3, 255, 255]
        target = {
            "min": shape,
            "opt": shape,
            "max": shape,
            "input_is_dynamic": False,
            "dtype": torchtrt.dtype.unknown,
            "format": torchtrt.TensorFormat.channels_last,
            "explicit_set_dtype": False
        }

        i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last)
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(shape, format=torch.channels_last)
        self.assertTrue(self._verify_correctness(i, target))

    def test_dynamic_shape(self):
        min_shape = [1, 3, 128, 128]
        opt_shape = [1, 3, 256, 256]
        max_shape = [1, 3, 512, 512]
        target = {
            "min": min_shape,
            "opt": opt_shape,
            "max": max_shape,
            "input_is_dynamic": True,
            "dtype": torchtrt.dtype.unknown,
            "format": torchtrt.TensorFormat.contiguous,
            "explicit_set_dtype": False
        }

        i = torchtrt.Input(min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape)
        self.assertTrue(self._verify_correctness(i, target))

        i = torchtrt.Input(min_shape=tuple(min_shape), opt_shape=tuple(opt_shape), max_shape=tuple(max_shape))
        self.assertTrue(self._verify_correctness(i, target))

        tensor_shape = lambda shape: torch.randn(shape).shape
        i = torchtrt.Input(min_shape=tensor_shape(min_shape),
                           opt_shape=tensor_shape(opt_shape),
                           max_shape=tensor_shape(max_shape))
        self.assertTrue(self._verify_correctness(i, target))

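
# Illustrative sketch (not executed by this test file) of how Device and Input
# objects like those above are typically consumed. torchtrt.compile and the
# inputs/enabled_precisions/device keyword arguments are assumed from the
# library's documented API rather than anything defined in this file:
#
#     trt_mod = torchtrt.compile(
#         scripted_module,  # hypothetical TorchScript module
#         inputs=[torchtrt.Input([1, 3, 224, 224], dtype=torch.half)],
#         enabled_precisions={torch.half},
#         device=torchtrt.Device("cuda:0"),
#     )
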

if __name__ == "__main__":
    unittest.main()