@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2262
2262
}
2263
2263
2264
2264
LogicalResult tosa::GatherOp::verify () {
2265
- return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
2266
- /* outType = */ getOutput ().getType ());
2265
+ if (verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
2266
+ /* outType = */ getOutput ().getType ())
2267
+ .failed ()) {
2268
+ return failure ();
2269
+ }
2270
+
2271
+ const ShapeAdaptor valuesShape (getValues ().getType ());
2272
+ const ShapeAdaptor indicesShape (getIndices ().getType ());
2273
+ const ShapeAdaptor outputShape (getOutput ().getType ());
2274
+
2275
+ int64_t N = ShapedType::kDynamic ;
2276
+ int64_t W = ShapedType::kDynamic ;
2277
+ int64_t C = ShapedType::kDynamic ;
2278
+
2279
+ if (valuesShape.hasRank ()) {
2280
+ N = valuesShape.getDimSize (0 );
2281
+ C = valuesShape.getDimSize (2 );
2282
+ }
2283
+ if (indicesShape.hasRank ()) {
2284
+ const int64_t indicesN = indicesShape.getDimSize (0 );
2285
+ W = indicesShape.getDimSize (1 );
2286
+ if (N == ShapedType::kDynamic )
2287
+ N = indicesN;
2288
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
2289
+ return emitOpError () << " requires indices dimension 0 to have size " << N
2290
+ << " , got " << indicesN;
2291
+ }
2292
+ if (outputShape.hasRank ()) {
2293
+ const int64_t outputN = outputShape.getDimSize (0 );
2294
+ const int64_t outputW = outputShape.getDimSize (1 );
2295
+ const int64_t outputC = outputShape.getDimSize (2 );
2296
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2297
+ N != outputN)
2298
+ return emitOpError () << " requires output dimension 0 to have size " << N
2299
+ << " , got " << outputN;
2300
+
2301
+ if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2302
+ W != outputW)
2303
+ return emitOpError () << " requires output dimension 1 to have size " << W
2304
+ << " , got " << outputW;
2305
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2306
+ C != outputC)
2307
+ return emitOpError () << " requires output dimension 2 to have size " << C
2308
+ << " , got " << outputC;
2309
+ }
2310
+ return success ();
2267
2311
}
2268
2312
2269
2313
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
0 commit comments