1
1
import unittest
2
2
3
3
import torch
4
+ import torchdynamo
5
+ import torchvision
6
+
7
+ from functorch import make_fx as make_fx_pk
8
+ from functorch .experimental import functionalize
4
9
from fx2trt_oss .tracer .dispatch_tracer .tracer import make_fx
10
+ from torch .library import Library
11
+ from torchdynamo .optimizations .normalize import normalize_ir
12
+ from torchdynamo .optimizations .python_key import fake_signature
5
13
6
14
torch .manual_seed (0 )
7
15
16
+ wrap_lib = Library ("wrap" , "DEF" )
17
+ """
18
+ There are two methods for setting leaf_module. leaf(op registeration) and leaf(override call_module)
19
+ Only leaf(op registeration) can work together with functionalize.
20
+ If you do not need funcitonalize, you can choose any of the leaf module methods.
21
+
22
+ Test coverage:
23
+ PythonkeyTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf(op registeration)
24
+
25
+ DispatchTracerTest.test_leaf_operator_reg: dispatch tracer + functionalize + leaf(op registeration)
26
+ DispatchTracerTest.test_leaf: dispatch tracer + leaf(override call_module)
27
+ DispatchTracerTest.test_non_tensor_input: dispatch tracer
28
+ DispatchTracerTest.test_resnet18: dispatch tracer
29
+ DispatchTracerTest.test_reference_copy: dispatch tracer + functionalize
30
+ DispatchTracerTest.test_reference_copy_torchdynamo: dispatcher tracer + torchdynamo + functionalize
31
+ """
32
+
33
+
34
+ class PythonkeyTracerTest (unittest .TestCase ):
35
+ def test_leaf_operator_reg (self ):
36
+ class Leaf (torch .nn .Module ):
37
+ def forward (self , x , y ):
38
+ return x + y + torch .nn .Parameter (torch .ones (5 ))
39
+
40
+ leaf = Leaf ()
41
+ wrap_lib .define ("wrapped_foo(Tensor x, Tensor y) -> Tensor" )
42
+ wrap_lib .impl ("wrapped_foo" , leaf , "CPU" )
43
+
44
+ class Bar (torch .nn .Module ):
45
+ def __init__ (self ):
46
+ super (Bar , self ).__init__ ()
47
+ self .foo = torch .ops .wrap .wrapped_foo
48
+ self .other = torch .nn .Parameter (torch .ones (5 ))
49
+
50
+ def forward (self , x , y ):
51
+ x = self .foo (x , y )
52
+ x = x + self .other
53
+ return x
54
+
55
+ mod = Bar ()
56
+
57
+ def f (x , y ):
58
+ return mod (x , y )
59
+
60
+ gm = make_fx_pk (functionalize (f ))(torch .ones (5 ), torch .ones (5 ))
61
+ inputs = [torch .ones (5 ) + 5 , torch .ones (5 ) + 8 ]
62
+ output = gm (* inputs )
63
+ ref_output = f (* inputs )
64
+ torch .testing .assert_close (output , ref_output )
65
+
8
66
9
67
class DispatchTracerTest (unittest .TestCase ):
10
- def test_leaf_module_list (self ):
11
- class TestModule (torch .nn .Module ):
68
+ def test_leaf_operator_reg (self ):
69
+ class Leaf (torch .nn .Module ):
70
+ def forward (self , x , y ):
71
+ return x + y + torch .nn .Parameter (torch .ones (5 ))
72
+
73
+ leaf = Leaf ()
74
+ wrap_lib .define ("wrapped_leaf(Tensor x, Tensor y) -> Tensor" )
75
+ wrap_lib .impl ("wrapped_leaf" , leaf , "CPU" )
76
+
77
+ class Bar (torch .nn .Module ):
78
+ def __init__ (self ):
79
+ super (Bar , self ).__init__ ()
80
+ self .leaf = torch .ops .wrap .wrapped_leaf
81
+ self .other = torch .nn .Parameter (torch .ones (5 ))
82
+
83
+ def forward (self , x , y ):
84
+ x = self .leaf (x , y )
85
+ x = x + self .other
86
+ return x
87
+
88
+ mod = Bar ()
89
+
90
+ def f (x , y ):
91
+ return mod (x , y )
92
+
93
+ gm = make_fx (functionalize (f ))(torch .ones (5 ), torch .ones (5 ))
94
+ inputs = [torch .ones (5 ) + 5 , torch .ones (5 ) + 8 ]
95
+ output = gm (* inputs )
96
+ ref_output = f (* inputs )
97
+ torch .testing .assert_close (output , ref_output )
98
+ # through the op registration method, the module is defined in a call_function
99
+ call_function_node = None
100
+ for node in gm .graph .nodes :
101
+ if (
102
+ node .op == "call_function"
103
+ and node .target == torch .ops .wrap .wrapped_leaf
104
+ ):
105
+ call_function_node = node
106
+ self .assertIsNotNone (call_function_node )
107
+
108
+ def test_leaf (self ):
109
+ class TestModuleLeaf (torch .nn .Module ):
12
110
def __init__ (self ):
13
111
super ().__init__ ()
14
112
self .conv = torch .nn .Conv2d (3 , 10 , 1 )
15
- self .relu = torch .nn .ReLU ()
113
+ self .relu = torch .nn .ReLU (inplace = True )
16
114
17
115
def forward (self , x ):
18
116
x = self .conv (x )
19
117
return self .relu (x )
20
118
119
+ class TestModule (torch .nn .Module ):
120
+ def __init__ (self ):
121
+ super ().__init__ ()
122
+
123
+ self .relu = torch .nn .ReLU (inplace = True )
124
+ self .leaf = TestModuleLeaf ()
125
+
126
+ def forward (self , x ):
127
+ x = self .leaf (x )
128
+ return self .relu (x )
129
+
21
130
mod = TestModule ()
22
131
23
132
def f (x ):
24
133
return mod (x )
25
134
26
135
a = torch .randn (1 , 3 , 1 , 1 )
27
136
ref_output = f (a )
28
- func = make_fx (f , leaf_module_list = {"torch.nn.modules.activation.ReLU " })
137
+ func = make_fx (f , leaf_module_list = {"test_dispatch_tracer.TestModuleLeaf " })
29
138
gm = func (a )
30
139
output = gm (a )
31
140
torch .testing .assert_close (output , ref_output )
@@ -36,17 +145,90 @@ def f(x):
36
145
if node .op == "call_module" :
37
146
call_module_node = node
38
147
self .assertIsNotNone (call_module_node )
39
- self .assertEqual (call_module_node .target , "ReLU_0 " )
148
+ self .assertEqual (call_module_node .target , "TestModuleLeaf_0 " )
40
149
41
150
def test_non_tensor_input (self ):
42
151
def foo (x ):
43
152
a = x ["a" ]
44
153
b = x ["b" ]
45
154
return a + b
46
155
47
- x = {"a" : torch .randn (1 ), "b" : torch .randn (1 )}
156
+ x = {"a" : torch .randn (2 , 2 ), "b" : torch .randn (2 , 2 )}
48
157
ref_output = foo (x )
49
158
func = make_fx (foo )
50
159
gm = func (x )
51
160
output = gm (x )
52
161
torch .testing .assert_close (output , ref_output )
162
+
163
+ def test_resnet18 (self ):
164
+ mod = torchvision .models .resnet18 (pretrained = False )
165
+
166
+ def f (x ):
167
+ return mod (x )
168
+
169
+ a = torch .randn (1 , 3 , 224 , 224 )
170
+ ref_output = f (a )
171
+ gm = make_fx (f )(a )
172
+ output = gm (a )
173
+ torch .testing .assert_close (output , ref_output )
174
+
175
+ def test_reference_copy (self ):
176
+ class TestModule (torch .nn .Module ):
177
+ def __init__ (self ):
178
+ super ().__init__ ()
179
+
180
+ def forward (self , x , y ):
181
+ y [:, 0 ] = x [:, 0 ]
182
+ return y
183
+
184
+ mod = TestModule ()
185
+
186
+ def f (x , y ):
187
+ return mod (x , y )
188
+
189
+ a = torch .ones (2 , 2 ) + 2
190
+ b = torch .ones (2 , 2 )
191
+ b_copy = torch .ones (2 , 2 )
192
+ ref_output = f (a , b )
193
+ gm = make_fx (functionalize (f ))(a , b )
194
+ output = gm (a , b_copy )
195
+ torch .testing .assert_close (output , ref_output )
196
+
197
+ def test_reference_copy_torchdynamo (self ):
198
+ class TestModule (torch .nn .Module ):
199
+ def __init__ (self ):
200
+ super ().__init__ ()
201
+ self .relu = torch .nn .ReLU (inplace = True )
202
+
203
+ def forward (self , x , y ):
204
+ y = y + 3
205
+ y = self .relu (y )
206
+ y [:, 0 ] = x [:, 0 ]
207
+ return y
208
+
209
+ mod = TestModule ()
210
+
211
+ def f (x , y ):
212
+ return mod (x , y )
213
+
214
+ a = torch .ones (2 , 2 ) + 2
215
+ b = torch .ones (2 , 2 )
216
+ inputs = [a , b ]
217
+ ref_output = f (* inputs )
218
+
219
+ def compile_dispatch (gm , example_inputs ):
220
+ # after normalization, relu in-place is removed
221
+ gm = normalize_ir (gm , example_inputs )
222
+ # dispatch tracer
223
+ nargs = len (example_inputs )
224
+ gm = make_fx (functionalize (fake_signature (gm , nargs )))(* example_inputs )
225
+ return gm
226
+
227
+ optimize_ctx = torchdynamo .optimize (
228
+ compile_dispatch ,
229
+ nopython = True ,
230
+ )
231
+
232
+ with optimize_ctx :
233
+ output = mod (* inputs )
234
+ torch .testing .assert_close (output , ref_output )
0 commit comments