@@ -2102,6 +2102,78 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
2102
2102
gm_retrace = acc_tracer .trace (gm , [a ])
2103
2103
self .assertTrue (torch .equal (m (a ), gm_retrace (a )))
2104
2104
2105
+ def test_index_select (self ):
2106
+ class TestModule (nn .Module ):
2107
+ def __init__ (self , dim , index ):
2108
+ super ().__init__ ()
2109
+ self ._dim = dim
2110
+ self ._index = index
2111
+
2112
+ def forward (self , a : torch .Tensor ) -> torch .Tensor :
2113
+ return torch .index_select (a , self ._dim , self ._index )
2114
+
2115
+ dim = 0
2116
+ index = torch .tensor ([1 , 0 ])
2117
+ m = TestModule (dim , index )
2118
+ _input = [torch .randn (2 , 3 ), torch .randn (2 , 3 )]
2119
+ traced = acc_tracer .trace (m , _input )
2120
+
2121
+ ph = index = index_select = None
2122
+
2123
+ for node in traced .graph .nodes :
2124
+ if node .op == "placeholder" :
2125
+ self .assertEqual (str (node .target ), "a" )
2126
+ ph = node
2127
+ elif node .op == "call_function" and node .target == acc_ops .index_select :
2128
+ self .assertTrue (node .kwargs ["input" ] == ph )
2129
+ self .assertTrue (node .kwargs ["index" ] == index )
2130
+ self .assertTrue (node .kwargs ["dim" ] == dim )
2131
+ index_select = node
2132
+ elif node .op == "output" :
2133
+ self .assertEqual (index_select , node .args [0 ])
2134
+ elif node .op == "get_attr" :
2135
+ # There only be one™ const node
2136
+ self .assertTrue (index is None )
2137
+ index = node
2138
+ else :
2139
+ self .fail (f"Unexpected node: { node .format_node ()} " )
2140
+
2141
+ def test_gather (self ):
2142
+ class TestModule (nn .Module ):
2143
+ def __init__ (self , dim , index ):
2144
+ super ().__init__ ()
2145
+ self ._dim = dim
2146
+ self ._index = index
2147
+
2148
+ def forward (self , a : torch .Tensor ) -> torch .Tensor :
2149
+ return torch .gather (a , self ._dim , self ._index )
2150
+
2151
+ dim = 0
2152
+ index = torch .tensor ([[1 , 0 ], [0 , 1 ]])
2153
+ m = TestModule (dim , index )
2154
+ _input = [torch .randn (2 , 3 ), torch .randn (2 , 3 )]
2155
+ traced = acc_tracer .trace (m , _input )
2156
+
2157
+ ph = index = gather = None
2158
+
2159
+ for node in traced .graph .nodes :
2160
+ if node .op == "placeholder" :
2161
+ self .assertEqual (str (node .target ), "a" )
2162
+ ph = node
2163
+ elif node .op == "call_function" and node .target == acc_ops .gather :
2164
+ self .assertTrue (node .kwargs ["input" ] == ph )
2165
+ self .assertTrue (node .kwargs ["index" ] == index )
2166
+ self .assertTrue (node .kwargs ["dim" ] == dim )
2167
+ gather = node
2168
+ elif node .op == "output" :
2169
+ self .assertEqual (gather , node .args [0 ])
2170
+ elif node .op == "get_attr" :
2171
+ # There only be one™ const node
2172
+ self .assertTrue (index is None )
2173
+ index = node
2174
+ else :
2175
+ self .fail (f"Unexpected node: { node .format_node ()} " )
2176
+
2105
2177
def test_all_acc_ops_registered (self ):
2106
2178
self .assertEqual (
2107
2179
acc_normalizer ._acc_ops ,
@@ -2203,5 +2275,7 @@ def test_all_acc_ops_registered(self):
2203
2275
acc_ops .eq ,
2204
2276
acc_ops .gt ,
2205
2277
acc_ops .le ,
2278
+ acc_ops .gather ,
2279
+ acc_ops .index_select ,
2206
2280
},
2207
2281
)
0 commit comments