@@ -291,6 +291,148 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
291
291
result.types .push_back (outputType);
292
292
}
293
293
294
+ // ===----------------------------------------------------------------------===//
295
+ // TOSA Operator Return Type Inference.
296
+ // ===----------------------------------------------------------------------===//
297
+
298
+ static void getI64Values (ArrayAttr arrayAttr, SmallVector<int64_t > &values) {
299
+ for (auto it : arrayAttr) {
300
+ values.push_back (it.cast <IntegerAttr>().getValue ().getSExtValue ());
301
+ }
302
+ }
303
+
304
+ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents (
305
+ MLIRContext *context, ::llvm::Optional<Location> location,
306
+ ValueRange operands, DictionaryAttr attributes, RegionRange regions,
307
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
308
+ ShapedType type = operands.front ().getType ().cast <ShapedType>();
309
+
310
+ auto newShape = attributes.get (" new_shape" ).cast <ArrayAttr>();
311
+ llvm::SmallVector<int64_t > newShapeValue;
312
+ getI64Values (newShape, newShapeValue);
313
+
314
+ // We cannot infer from the total number of elements so we must take the
315
+ // shape attribute as exact.
316
+ if (!type.hasRank () || !type.hasStaticShape ()) {
317
+ inferredReturnShapes.push_back (ShapedTypeComponents (newShapeValue));
318
+ return success ();
319
+ }
320
+
321
+ // Determine the number of elements covered by the slice of all static
322
+ // dimensions. This allows us to infer the length of the remaining dynamic
323
+ // dimension.
324
+ int64_t numElements = type.getNumElements ();
325
+ int64_t staticMul = 1 ;
326
+ for (auto val : newShapeValue) {
327
+ if (val != -1 ) {
328
+ staticMul *= val;
329
+ }
330
+ }
331
+
332
+ // Determine the length of the dynamic dimension.
333
+ for (auto &val : newShapeValue) {
334
+ if (val == -1 )
335
+ val = numElements / staticMul;
336
+ }
337
+
338
+ inferredReturnShapes.push_back (ShapedTypeComponents (newShapeValue));
339
+ return success ();
340
+ }
341
+
342
+ static LogicalResult resolveBroadcastShape (ValueRange operands,
343
+ SmallVector<int64_t > &outShape) {
344
+ int64_t outRank = 0 ;
345
+ for (auto operand : operands) {
346
+ auto type = operand.getType ().cast <ShapedType>();
347
+ if (!type.hasRank ())
348
+ return failure ();
349
+ outRank = std::max<int64_t >(outRank, type.getRank ());
350
+ }
351
+
352
+ outShape.resize (outRank, 1 );
353
+
354
+ for (auto operand : operands) {
355
+ auto type = operand.getType ().cast <ShapedType>();
356
+ auto shape = type.getShape ();
357
+ auto rankDiff = outShape.size () - shape.size ();
358
+
359
+ for (size_t i = 0 ; i < shape.size (); i++) {
360
+ auto dim1 = outShape[i + rankDiff];
361
+ auto dim2 = shape[i];
362
+ auto resolvedDim = dim1;
363
+
364
+ if (dim1 == 1 ) {
365
+ resolvedDim = dim2;
366
+ } else if (dim2 == 1 ) {
367
+ resolvedDim = dim1;
368
+ } else if (dim1 != dim2) {
369
+ return failure ();
370
+ }
371
+ outShape[i + rankDiff] = resolvedDim;
372
+ }
373
+ }
374
+
375
+ return success ();
376
+ }
377
+
378
+ static LogicalResult NAryInferReturnTypes (
379
+ ValueRange operands,
380
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
381
+ llvm::SmallVector<int64_t > outShape;
382
+ if (resolveBroadcastShape (operands, outShape).failed ()) {
383
+ inferredReturnShapes.push_back (ShapedTypeComponents ());
384
+ } else {
385
+ inferredReturnShapes.push_back (ShapedTypeComponents (outShape));
386
+ }
387
+ return success ();
388
+ }
389
+
390
+ #define NARY_SHAPE_INFER (OP ) \
391
+ LogicalResult OP::inferReturnTypeComponents ( \
392
+ MLIRContext *context, ::llvm::Optional<Location> location, \
393
+ ValueRange operands, DictionaryAttr attributes, RegionRange regions, \
394
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
395
+ return NAryInferReturnTypes (operands, inferredReturnShapes); \
396
+ }
397
+
398
+ NARY_SHAPE_INFER (tosa::AbsOp)
399
+ NARY_SHAPE_INFER(tosa::AddOp)
400
+ NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
401
+ NARY_SHAPE_INFER(tosa::BitwiseAndOp)
402
+ NARY_SHAPE_INFER(tosa::BitwiseOrOp)
403
+ NARY_SHAPE_INFER(tosa::BitwiseXorOp)
404
+ NARY_SHAPE_INFER(tosa::BitwiseNotOp)
405
+ NARY_SHAPE_INFER(tosa::CeilOp)
406
+ NARY_SHAPE_INFER(tosa::ClampOp)
407
+ NARY_SHAPE_INFER(tosa::ClzOp)
408
+ NARY_SHAPE_INFER(tosa::DivOp)
409
+ NARY_SHAPE_INFER(tosa::EqualOp)
410
+ NARY_SHAPE_INFER(tosa::ExpOp)
411
+ NARY_SHAPE_INFER(tosa::FloorOp)
412
+ NARY_SHAPE_INFER(tosa::GreaterEqualOp)
413
+ NARY_SHAPE_INFER(tosa::GreaterOp)
414
+ NARY_SHAPE_INFER(tosa::LogOp)
415
+ NARY_SHAPE_INFER(tosa::LogicalAndOp)
416
+ NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
417
+ NARY_SHAPE_INFER(tosa::LogicalNotOp)
418
+ NARY_SHAPE_INFER(tosa::LogicalOrOp)
419
+ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
420
+ NARY_SHAPE_INFER(tosa::LogicalXorOp)
421
+ NARY_SHAPE_INFER(tosa::MaximumOp)
422
+ NARY_SHAPE_INFER(tosa::MinimumOp)
423
+ NARY_SHAPE_INFER(tosa::MulOp)
424
+ NARY_SHAPE_INFER(tosa::NegateOp)
425
+ NARY_SHAPE_INFER(tosa::PowOp)
426
+ NARY_SHAPE_INFER(tosa::ReciprocalOp)
427
+ NARY_SHAPE_INFER(tosa::ReluNOp)
428
+ NARY_SHAPE_INFER(tosa::ReverseOp)
429
+ NARY_SHAPE_INFER(tosa::RsqrtOp)
430
+ NARY_SHAPE_INFER(tosa::SelectOp)
431
+ NARY_SHAPE_INFER(tosa::SubOp)
432
+ NARY_SHAPE_INFER(tosa::TanhOp)
433
+ NARY_SHAPE_INFER(tosa::SigmoidOp)
434
+ #undef PRED_SHAPE_INFER
435
+
294
436
// ===----------------------------------------------------------------------===//
295
437
// TOSA Operator Definitions.
296
438
// ===----------------------------------------------------------------------===//
0 commit comments