@@ -2504,39 +2504,34 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
   }
 };
 
-template <typename SrcOp>
-class Pool2dConverter : public OpRewritePattern<SrcOp> {
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SrcOp op,
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     Value input = op.input();
     ShapedType inputTy = input.getType().cast<ShapedType>();
-    Type inElementTy = inputTy.getElementType();
 
     ShapedType resultTy = op.getType().template cast<ShapedType>();
-    Type outElementTy = inputTy.getElementType();
+    Type resultETy = inputTy.getElementType();
 
     if (!inputTy.hasStaticShape())
      return failure();
 
     // Determine what the initial value needs to be for the max pool op.
     Attribute initialAttr;
-    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
+    if (resultETy.isF32())
       initialAttr = rewriter.getFloatAttr(
-          outElementTy,
-          APFloat::getLargest(
-              outElementTy.cast<FloatType>().getFloatSemantics(), true));
+          resultETy,
+          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
+                              true));
 
-    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
+    if (resultETy.isa<IntegerType>())
       initialAttr = rewriter.getIntegerAttr(
-          outElementTy,
-          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
-
-    if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
-      initialAttr = rewriter.getZeroAttr(outElementTy);
+          resultETy,
+          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
 
     if (!initialAttr)
       return rewriter.notifyMatchFailure(
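The refactored `MaxPool2dConverter` no longer branches on the op kind: the accumulator is always seeded with the identity of `max`, i.e. the largest-magnitude negative float for `f32`, or the signed minimum for integer types. A minimal standalone sketch of those two seed values, using only LLVM's ADT types (the MLIR rewriter plumbing is elided, so this is illustrative rather than the in-tree code):

```cpp
// Sketch: the seed values the converter materializes, shown without MLIR.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // f32 seed: -FLT_MAX, so any real input element wins the running max.
  llvm::APFloat f32Seed = llvm::APFloat::getLargest(
      llvm::APFloat::IEEEsingle(), /*Negative=*/true);

  // i8 seed: INT8_MIN, the identity for signed integer max.
  llvm::APInt i8Seed = llvm::APInt::getSignedMinValue(8);

  llvm::SmallString<32> buf;
  f32Seed.toString(buf);
  llvm::outs() << "f32 seed: " << buf << "\n";                   // -3.40282347E+38
  llvm::outs() << "i8 seed:  " << i8Seed.getSExtValue() << "\n"; // -128
}
```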
@@ -2566,93 +2561,216 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
         rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
 
     Value fakeWindowDims =
-        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);
 
-    if (isa<tosa::MaxPool2dOp>(op)) {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledInitTensor, strideAttr, dilationAttr);
-      return success();
-    }
+    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+        filledInitTensor, strideAttr, dilationAttr);
+    return success();
+  }
+};
 
-    if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
-      Value poolingOp = rewriter
-                            .create<linalg::PoolingNhwcSumOp>(
-                                loc, ArrayRef<Type>{resultTy},
-                                ValueRange{paddedInput, fakeWindowDims},
-                                filledInitTensor, strideAttr, dilationAttr)
-                            .getResult(0);
-      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
-      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
-          ArrayRef<AffineMap>({affineMap}),
-          getNParallelLoopsAttrs(resultTy.getRank()),
-          [&](OpBuilder &b, Location loc, ValueRange args) {
-            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-            auto iH = rewriter.create<ConstantIndexOp>(
-                loc, poolingOpTy.getDimSize(1) - 1);
-            auto iW = rewriter.create<ConstantIndexOp>(
-                loc, poolingOpTy.getDimSize(2) - 1);
-
-            // Compute the indices from either end.
-            auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
-            auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
-            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
-            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
-
-            // Determines what the portion of valid input is covered by the
-            // kernel.
-            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
-              if (pad == 0)
-                return v;
-
-              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
-              Value dx = rewriter.create<SubIOp>(loc, x, padVal);
-
-              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
-                                                        dx, zero);
-              Value offset =
-                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
-              return rewriter.create<mlir::AddIOp>(loc, v, offset)
-                  ->getResult(0);
-            };
-
-            // Compute the vertical component of coverage.
-            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
-            auto kH1 = padFn(kH0, y0, pad[2]);
-            auto kH2 = padFn(kH1, y1, pad[3]);
-            auto kHCmp =
-                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
-            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
-
-            // compute teh horizontal component of coverage.
-            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
-            auto kW1 = padFn(kW0, x0, pad[4]);
-            auto kW2 = padFn(kW1, x1, pad[5]);
-            auto kWCmp =
-                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
-            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
-
-            // Compute the total number of elements and normalize.
-            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
-            auto countI = rewriter.create<mlir::IndexCastOp>(
-                loc, rewriter.getI32Type(), count);
+class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
+public:
+  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type inElementTy = inputTy.getElementType();
+
+    ShapedType resultTy = op.getType().template cast<ShapedType>();
+    Type resultETy = inputTy.getElementType();
+
+    Type accETy =
+        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+    ShapedType accTy = resultTy.clone(accETy);
+
+    if (!inputTy.hasStaticShape())
+      return failure();
+
+    // Apply padding as necessary.
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+    pad.resize(pad.size() + 2, 0);
+    Attribute initialAttr = rewriter.getZeroAttr(accETy);
+    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
+
+    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+    SmallVector<int64_t> kernel, stride;
+    getValuesFromIntArrayAttribute(op.kernel(), kernel);
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+
+    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+    // Create the linalg op that performs pooling.
+    Value poolInitTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
+
+    Value filledInitTensor =
+        rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
+            .result();
+
+    Value fakeWindowDims =
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);
+
+    // Sum across the pooled region.
+    Value poolingOp = rewriter
+                          .create<linalg::PoolingNhwcSumOp>(
+                              loc, ArrayRef<Type>{accTy},
+                              ValueRange{paddedInput, fakeWindowDims},
+                              filledInitTensor, strideAttr, dilationAttr)
+                          .getResult(0);
+
+    // Normalize the summed value by the number of elements grouped in each
+    // pool.
+    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+
+    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
+        ValueRange{genericInitTensor},
+        ArrayRef<AffineMap>({affineMap, affineMap}),
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+          auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+          auto iH = rewriter.create<ConstantIndexOp>(
+              loc, poolingOpTy.getDimSize(1) - 1);
+          auto iW = rewriter.create<ConstantIndexOp>(
+              loc, poolingOpTy.getDimSize(2) - 1);
+
+          // Compute the indices from either end.
+          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
+          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
+          auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+          auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+          // Determines what the portion of valid input is covered by the
+          // kernel.
+          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+            if (pad == 0)
+              return v;
+
+            auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+            Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+            Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                      dx, zero);
+            Value offset = rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
+            return rewriter.create<mlir::AddIOp>(loc, v, offset)->getResult(0);
+          };
+
+          // Compute the vertical component of coverage.
+          auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+          auto kH1 = padFn(kH0, y0, pad[2]);
+          auto kH2 = padFn(kH1, y1, pad[3]);
+          auto kHCmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+          auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+          // Compute the horizontal component of coverage.
+          auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+          auto kW1 = padFn(kW0, x0, pad[4]);
+          auto kW2 = padFn(kW1, x1, pad[5]);
+          auto kWCmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+          auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+          // Compute the total number of elements and normalize.
+          Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+          auto countI = rewriter.create<mlir::IndexCastOp>(
+              loc, rewriter.getI32Type(), count);
+
+          // Divide by the number of summed values. For floats this is just
+          // a div; for quantized values, input normalization must also be
+          // applied.
+          Value poolVal = args[0];
+          if (accETy.isa<FloatType>()) {
             auto countF =
                 rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+            poolVal =
+                rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
+          } else {
 
-            auto div =
-                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+            // If we have quantization information we need to apply an offset
+            // for the input zp value.
+            if (op.quantization_info()) {
+              auto quantizationInfo = op.quantization_info().getValue();
+              auto inputZp = rewriter.create<mlir::ConstantOp>(
+                  loc, quantizationInfo.input_zp());
+              Value offset =
+                  rewriter.create<mlir::MulIOp>(loc, accETy, countI, inputZp);
+              poolVal = rewriter.create<SubIOp>(loc, accETy, poolVal, offset);
+            }
 
-            rewriter.create<linalg::YieldOp>(loc, div);
-          });
+            // Compute the multiplier and shift values for the quantization
+            // normalization. Preferably we would want to compute more bits,
+            // however 32 bits should be enough for compute. Honestly we
+            // should probably straight divide.
+            int64_t numerator = ((1 << 30) + 1);
+            int64_t shift = 30;
+
+            Value numeratorVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(numerator));
+            Value multiplierVal =
+                rewriter
+                    .create<UnsignedDivIOp>(loc, rewriter.getI32Type(),
+                                            numeratorVal, countI)
+                    .getResult();
+            Value shiftVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getI8IntegerAttr(shift));
+
+            auto scaled =
+                rewriter
+                    .create<tosa::ApplyScaleOp>(
+                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
+                        shiftVal, rewriter.getBoolAttr(false))
+                    .getResult();
+
+            // If we have quantization information we need to apply the
+            // output zero-point.
+            if (op.quantization_info()) {
+              auto quantizationInfo = op.quantization_info().getValue();
+              auto outputZp = rewriter.create<mlir::ConstantOp>(
+                  loc, quantizationInfo.output_zp());
+              scaled =
+                  rewriter.create<AddIOp>(loc, scaled, outputZp).getResult();
+            }
 
-      rewriter.replaceOp(op, genericOp.getResult(0));
-      return success();
-    }
+            // Apply clip.
+            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
+
+            auto min = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(
+                         accETy,
+                         APInt::getSignedMinValue(outBitwidth).getSExtValue()));
+            auto max = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(
+                         accETy,
+                         APInt::getSignedMaxValue(outBitwidth).getSExtValue()));
+            auto clamp = clampHelper<mlir::CmpIOp>(
+                loc, scaled, min, max, CmpIPredicate::slt, rewriter);
+
+            // Convert to the output type.
+            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
+          }
 
-    return failure();
+          rewriter.create<linalg::YieldOp>(loc, poolVal);
+        });
+
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
   }
 };
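The integer path above normalizes without a per-element division of the sum: it computes `multiplier = ((1 << 30) + 1) / count` and `shift = 30`, then hands both to `tosa.apply_scale`. A self-contained sketch of that arithmetic with hypothetical values; the rounding below models the single-round (`double_round = false`) behavior and may differ from the dialect's exact semantics in corner cases:

```cpp
// Sketch: reciprocal-multiplier normalization for quantized avg_pool2d.
#include <cstdint>
#include <iostream>

// Models tosa.apply_scale with double_round = false:
// (value * multiplier + 2^(shift-1)) >> shift, widened to 64 bits.
int32_t applyScaleRef(int32_t value, int32_t multiplier, int8_t shift) {
  int64_t round = int64_t(1) << (shift - 1);
  return int32_t((int64_t(value) * multiplier + round) >> shift);
}

int main() {
  int32_t count = 9;  // full 3x3 window, no border clipping
  int32_t sum = 450;  // hypothetical i32 accumulator value
  int32_t multiplier = ((1 << 30) + 1) / count;  // 119304647
  std::cout << applyScaleRef(sum, multiplier, /*shift=*/30) << "\n";  // 50
}
```

At the borders `count` shrinks to the number of in-bounds elements (the `padFn` coverage logic), which is why the emitted code recomputes the multiplier per output element with an `UnsignedDivIOp` rather than folding it to a constant.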
@@ -2719,8 +2837,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       TileConverter,
       TransposeConverter,
       MatMulConverter,
-      Pool2dConverter<tosa::AvgPool2dOp>,
-      Pool2dConverter<tosa::MaxPool2dOp>,
+      MaxPool2dConverter,
+      AvgPool2dConverter,
       FullyConnectedConverter>(patterns->getContext());
   // clang-format on
 }
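For reference, the full per-element sequence the new `AvgPool2dConverter` emits on the quantized path, collapsed into plain C++. This is a sketch under the assumption of i8 input/output with i32 accumulation; `quantizedAvgRef` is a hypothetical helper, and the zero-point adjustments only apply when `quantization_info` is present:

```cpp
// Sketch: quantized average-pool normalization, mirroring the generic-op
// body in the patch. Hypothetical helper, not the in-tree implementation.
#include <algorithm>
#include <cstdint>

int8_t quantizedAvgRef(int32_t sum, int32_t count, int32_t inZp,
                       int32_t outZp) {
  // Remove the input zero-point bias accumulated over `count` elements.
  sum -= count * inZp;
  // Normalize by ~1/count via reciprocal multiplier and shift (see above).
  int32_t multiplier = ((1 << 30) + 1) / count;
  int64_t round = int64_t(1) << 29;  // 2^(shift-1) with shift = 30
  int32_t scaled = int32_t((int64_t(sum) * multiplier + round) >> 30);
  // Re-bias to the output zero-point, clamp to the i8 range, truncate.
  scaled += outZp;
  return int8_t(std::clamp(scaled, int32_t(INT8_MIN), int32_t(INT8_MAX)));
}
```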