@@ -2097,29 +2097,29 @@ def infer_shape(self, fgraph, node, input_shapes):
2097
2097
# sequences
2098
2098
seqs_shape = [x [1 :] for x in input_shapes [1 : 1 + info .n_seqs ]]
2099
2099
# We disable extra infer_shape for now. See gh-3765.
2100
- extra_infer_shape = False
2100
+ # if extra_infer_shape:
2101
+ # inner_seqs = self.inputs[: info.n_seqs]
2102
+ # outer_seqs = node.inputs[1 : 1 + info.n_seqs]
2103
+ # for in_s, out_s in zip(inner_seqs, outer_seqs):
2104
+ # out_equivalent[in_s] = out_s[0]
2105
+ #
2106
+ # # mit_mot, mit_sot, sit_sot
2107
+ # outer_inp_idx = 1 + info.n_seqs
2108
+ # inner_inp_idx = info.n_seqs
2109
+ # else:
2110
+ # outer_inp_idx = 0
2111
+ outer_inp_idx = 0
2101
2112
2102
- if extra_infer_shape :
2103
- inner_seqs = self .inputs [: info .n_seqs ]
2104
- outer_seqs = node .inputs [1 : 1 + info .n_seqs ]
2105
- for in_s , out_s in zip (inner_seqs , outer_seqs ):
2106
- out_equivalent [in_s ] = out_s [0 ]
2107
-
2108
- # mit_mot, mit_sot, sit_sot
2109
- outer_inp_idx = 1 + info .n_seqs
2110
- inner_inp_idx = info .n_seqs
2111
- else :
2112
- outer_inp_idx = 0
2113
2113
n_outs = info .n_mit_mot + info .n_mit_sot + info .n_sit_sot
2114
2114
outs_shape = []
2115
2115
for idx in range (n_outs ):
2116
- mintap = abs (min (info .tap_array [idx ]))
2116
+ abs (min (info .tap_array [idx ]))
2117
2117
for k in info .tap_array [idx ]:
2118
2118
outs_shape += [input_shapes [idx + info .n_seqs + 1 ][1 :]]
2119
- if extra_infer_shape :
2120
- corresponding_tap = node .inputs [outer_inp_idx ][mintap + k ]
2121
- out_equivalent [self .inputs [inner_inp_idx ]] = corresponding_tap
2122
- inner_inp_idx += 1
2119
+ # if extra_infer_shape:
2120
+ # corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
2121
+ # out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
2122
+ # inner_inp_idx += 1
2123
2123
outer_inp_idx += 1
2124
2124
2125
2125
# shared_outs
0 commit comments