@@ -269,15 +269,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
269
269
SmallVector<Type> opResultTypes;
270
270
SmallVector<Value> initTensors;
271
271
for (auto result : results) {
272
- auto resultType = result.getType ().template cast <ShapedType>();
273
- if (!resultType .hasStaticShape ())
272
+ auto resultTy = result.getType ().template cast <ShapedType>();
273
+ if (!resultTy .hasStaticShape ())
274
274
return rewriter.notifyMatchFailure (
275
275
operation,
276
276
" tosa to linalg conversion expects statically shaped tensors" );
277
277
278
278
initTensors.push_back (rewriter.create <linalg::InitTensorOp>(
279
- loc, ArrayRef<Value>({}), resultType .getShape (),
280
- resultType .getElementType ()));
279
+ loc, ArrayRef<Value>({}), resultTy .getShape (),
280
+ resultTy .getElementType ()));
281
281
opResultTypes.push_back (result.getType ());
282
282
}
283
283
@@ -330,6 +330,152 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
330
330
return success ();
331
331
}
332
332
333
+ // Returns the constant initial value for a given reduction operation. The
334
+ // attribute type varies depending on the element type required.
335
+ static Attribute createInitialValueForReduceOp (Operation *op, Type elementTy,
336
+ PatternRewriter &rewriter) {
337
+ if (isa<tosa::ReduceSumOp>(op) && elementTy.isa <FloatType>())
338
+ return rewriter.getFloatAttr (elementTy, 0.0 );
339
+
340
+ if (isa<tosa::ReduceSumOp>(op) && elementTy.isa <IntegerType>())
341
+ return rewriter.getIntegerAttr (elementTy, 0 );
342
+
343
+ if (isa<tosa::ReduceProdOp>(op) && elementTy.isa <FloatType>())
344
+ return rewriter.getFloatAttr (elementTy, 1.0 );
345
+
346
+ if (isa<tosa::ReduceProdOp>(op) && elementTy.isa <IntegerType>())
347
+ return rewriter.getIntegerAttr (elementTy, 1 );
348
+
349
+ if (isa<tosa::ReduceMinOp>(op) && elementTy.isa <FloatType>())
350
+ return rewriter.getFloatAttr (
351
+ elementTy, APFloat::getLargest (
352
+ elementTy.cast <FloatType>().getFloatSemantics (), false ));
353
+
354
+ if (isa<tosa::ReduceMinOp>(op) && elementTy.isa <IntegerType>())
355
+ return rewriter.getIntegerAttr (
356
+ elementTy, APInt::getSignedMaxValue (elementTy.getIntOrFloatBitWidth ()));
357
+
358
+ if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa <FloatType>())
359
+ return rewriter.getFloatAttr (
360
+ elementTy, APFloat::getLargest (
361
+ elementTy.cast <FloatType>().getFloatSemantics (), true ));
362
+
363
+ if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa <IntegerType>())
364
+ return rewriter.getIntegerAttr (
365
+ elementTy, APInt::getSignedMinValue (elementTy.getIntOrFloatBitWidth ()));
366
+
367
+ return {};
368
+ }
369
+
370
+ // Creates the body calculation for a reduction. The operations vary depending
371
+ // on the input type.
372
+ static Value createLinalgBodyCalculationForReduceOp (Operation *op,
373
+ ValueRange args,
374
+ Type elementTy,
375
+ PatternRewriter &rewriter) {
376
+ Location loc = op->getLoc ();
377
+ if (isa<tosa::ReduceSumOp>(op) && elementTy.isa <FloatType>()) {
378
+ return rewriter.create <AddFOp>(loc, args);
379
+ }
380
+
381
+ if (isa<tosa::ReduceSumOp>(op) && elementTy.isa <IntegerType>()) {
382
+ return rewriter.create <AddIOp>(loc, args);
383
+ }
384
+
385
+ if (isa<tosa::ReduceProdOp>(op) && elementTy.isa <FloatType>()) {
386
+ return rewriter.create <MulFOp>(loc, args);
387
+ }
388
+
389
+ if (isa<tosa::ReduceProdOp>(op) && elementTy.isa <IntegerType>()) {
390
+ return rewriter.create <MulIOp>(loc, args);
391
+ }
392
+
393
+ if (isa<tosa::ReduceMinOp>(op) && elementTy.isa <FloatType>()) {
394
+ auto predicate = rewriter.create <mlir::CmpFOp>(loc, CmpFPredicate::OLT,
395
+ args[0 ], args[1 ]);
396
+ return rewriter.create <mlir::SelectOp>(loc, predicate, args[0 ], args[1 ]);
397
+ }
398
+
399
+ if (isa<tosa::ReduceMinOp>(op) && elementTy.isa <IntegerType>()) {
400
+ auto predicate = rewriter.create <mlir::CmpIOp>(loc, CmpIPredicate::slt,
401
+ args[0 ], args[1 ]);
402
+ return rewriter.create <mlir::SelectOp>(loc, predicate, args[0 ], args[1 ]);
403
+ }
404
+
405
+ if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa <FloatType>()) {
406
+ auto predicate = rewriter.create <mlir::CmpFOp>(loc, CmpFPredicate::OGT,
407
+ args[0 ], args[1 ]);
408
+ return rewriter.create <mlir::SelectOp>(loc, predicate, args[0 ], args[1 ]);
409
+ }
410
+
411
+ if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa <IntegerType>()) {
412
+ auto predicate = rewriter.create <mlir::CmpIOp>(loc, CmpIPredicate::sgt,
413
+ args[0 ], args[1 ]);
414
+ return rewriter.create <mlir::SelectOp>(loc, predicate, args[0 ], args[1 ]);
415
+ }
416
+
417
+ return {};
418
+ }
419
+
420
+ // Performs the match and rewrite for reduction operations. This includes
421
+ // declaring a correctly sized initial value, and the linalg.generic operation
422
+ // that reduces across the specified axis.
423
+ static LogicalResult reduceMatchAndRewriteHelper (Operation *op, uint64_t axis,
424
+ PatternRewriter &rewriter) {
425
+ auto loc = op->getLoc ();
426
+ auto inputTy = op->getOperand (0 ).getType ().template cast <ShapedType>();
427
+ auto resultTy = op->getResult (0 ).getType ().template cast <ShapedType>();
428
+ auto elementTy = resultTy.getElementType ();
429
+ Value input = op->getOperand (0 );
430
+
431
+ // First fill the output buffer with the init value.
432
+ auto initTensor = rewriter
433
+ .create <linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
434
+ resultTy.getShape (),
435
+ resultTy.getElementType ())
436
+ .result ();
437
+
438
+ auto fillValueAttr = createInitialValueForReduceOp (op, elementTy, rewriter);
439
+ if (!fillValueAttr)
440
+ return rewriter.notifyMatchFailure (
441
+ op, " No initial value found for reduction operation" );
442
+
443
+ auto fillValue = rewriter.create <ConstantOp>(loc, fillValueAttr);
444
+ auto filledTensor =
445
+ rewriter.create <linalg::FillOp>(loc, initTensor, fillValue).result ();
446
+
447
+ SmallVector<AffineExpr, 2 > srcExprs;
448
+ SmallVector<AffineExpr, 2 > dstExprs;
449
+ SmallVector<StringRef, 4 > iteratorTypes;
450
+ for (unsigned int i = 0 , rank = inputTy.getRank (); i != rank; ++i) {
451
+ srcExprs.push_back (mlir::getAffineDimExpr (i, rewriter.getContext ()));
452
+
453
+ iteratorTypes.push_back (axis == i ? getReductionIteratorTypeName ()
454
+ : getParallelIteratorTypeName ());
455
+ if (axis != i)
456
+ dstExprs.push_back (mlir::getAffineDimExpr (i, rewriter.getContext ()));
457
+ }
458
+
459
+ bool didEncounterError = false ;
460
+ auto maps = AffineMap::inferFromExprList ({srcExprs, dstExprs});
461
+ auto linalgOp = rewriter.create <linalg::GenericOp>(
462
+ loc, resultTy, input, filledTensor, maps, iteratorTypes,
463
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
464
+ auto result = createLinalgBodyCalculationForReduceOp (
465
+ op, blockArgs, elementTy, rewriter);
466
+ if (result)
467
+ didEncounterError = true ;
468
+
469
+ nestedBuilder.create <linalg::YieldOp>(loc, result);
470
+ });
471
+
472
+ if (!didEncounterError)
473
+ return failure ();
474
+
475
+ rewriter.replaceOp (op, linalgOp.getOperation ()->getResults ());
476
+ return success ();
477
+ }
478
+
333
479
namespace {
334
480
335
481
template <typename SrcOp>
@@ -500,6 +646,17 @@ class IdentityNConverter : public OpRewritePattern<SrcOp> {
500
646
}
501
647
};
502
648
649
+ template <typename SrcOp>
650
+ class ReduceConverter : public OpRewritePattern <SrcOp> {
651
+ public:
652
+ using OpRewritePattern<SrcOp>::OpRewritePattern;
653
+
654
+ LogicalResult matchAndRewrite (SrcOp reduceOp,
655
+ PatternRewriter &rewriter) const final {
656
+ return reduceMatchAndRewriteHelper (reduceOp, reduceOp.axis (), rewriter);
657
+ }
658
+ };
659
+
503
660
} // namespace
504
661
505
662
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns (
@@ -521,6 +678,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
521
678
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
522
679
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
523
680
IdentityNConverter<tosa::IdentityOp>,
524
- IdentityNConverter<tosa::IdentityNOp>,
525
- ReshapeOpConverter, TransposeConverter>(context);
681
+ IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
682
+ ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
683
+ ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
684
+ TransposeConverter>(context);
526
685
}
0 commit comments