@@ -2167,8 +2167,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2167
2167
}
2168
2168
2169
2169
LogicalResult tosa::GatherOp::verify () {
2170
- return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
2171
- /* outType = */ getOutput ().getType ());
2170
+ if (verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
2171
+ /* outType = */ getOutput ().getType ())
2172
+ .failed ()) {
2173
+ return failure ();
2174
+ }
2175
+
2176
+ const ShapeAdaptor valuesShape (getValues ().getType ());
2177
+ const ShapeAdaptor indicesShape (getIndices ().getType ());
2178
+ const ShapeAdaptor outputShape (getOutput ().getType ());
2179
+
2180
+ int64_t N = ShapedType::kDynamic ;
2181
+ int64_t W = ShapedType::kDynamic ;
2182
+ int64_t C = ShapedType::kDynamic ;
2183
+
2184
+ if (valuesShape.hasRank ()) {
2185
+ N = valuesShape.getDimSize (0 );
2186
+ C = valuesShape.getDimSize (2 );
2187
+ }
2188
+ if (indicesShape.hasRank ()) {
2189
+ const int64_t indicesN = indicesShape.getDimSize (0 );
2190
+ W = indicesShape.getDimSize (1 );
2191
+ if (N == ShapedType::kDynamic )
2192
+ N = indicesN;
2193
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
2194
+ return emitOpError () << " requires indices dimension 0 to have size " << N
2195
+ << " , got " << indicesN;
2196
+ }
2197
+ if (outputShape.hasRank ()) {
2198
+ const int64_t outputN = outputShape.getDimSize (0 );
2199
+ const int64_t outputW = outputShape.getDimSize (1 );
2200
+ const int64_t outputC = outputShape.getDimSize (2 );
2201
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2202
+ N != outputN)
2203
+ return emitOpError () << " requires output dimension 0 to have size " << N
2204
+ << " , got " << outputN;
2205
+
2206
+ if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2207
+ W != outputW)
2208
+ return emitOpError () << " requires output dimension 1 to have size " << W
2209
+ << " , got " << outputW;
2210
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2211
+ C != outputC)
2212
+ return emitOpError () << " requires output dimension 2 to have size " << C
2213
+ << " , got " << outputC;
2214
+ }
2215
+ return success ();
2172
2216
}
2173
2217
2174
2218
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
0 commit comments