Commit 5541a05
[mlir][tosa] Quantized tosa.avg_pool2d lowering to linalg

Adds the quantized version of the average pool lowering to the linalg dialect, including a lit test for the transform. The result is not 100% exact, as the multiplier/shift should be computed in i64 rather than i32; however, the rounding difference is negligible.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108676
1 parent 4ef1770 · commit 5541a05
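The i64-vs-i32 caveat above refers to the fixed-point normalization step: rather than dividing the pooled sum by the element count, the lowering multiplies by a reciprocal multiplier and shifts, via tosa.apply_scale. Below is a minimal standalone sketch of that arithmetic (illustrative only; apply_scale's rounding is simplified, and `fixedPointAverage` is a made-up name):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the multiplier/shift normalization: multiplier ~= 2^30 / count,
// then average = round(sum * multiplier / 2^30). Computing the multiplier
// in i32 (as the lowering does) rather than i64 only perturbs the result
// by roughly one unit in the last place.
int32_t fixedPointAverage(int32_t sum, int32_t count) {
  const int64_t shift = 30;
  const int32_t multiplier = ((1 << 30) + 1) / count; // i32, as in the commit
  const int64_t prod = static_cast<int64_t>(sum) * multiplier;
  return static_cast<int32_t>((prod + (1LL << (shift - 1))) >> shift);
}

int main() {
  // A 4x4 window whose 16 inputs sum to 1000: the exact average is 62.5;
  // the fixed-point result (63) lands within one unit of the true quotient.
  std::printf("%d\n", fixedPointAverage(1000, 16));
  return 0;
}
```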

File tree: 2 files changed, +254 −106 lines

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (215 additions, 97 deletions)
```diff
@@ -2504,39 +2504,34 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
   }
 };
 
-template <typename SrcOp>
-class Pool2dConverter : public OpRewritePattern<SrcOp> {
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SrcOp op,
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     Value input = op.input();
     ShapedType inputTy = input.getType().cast<ShapedType>();
-    Type inElementTy = inputTy.getElementType();
 
     ShapedType resultTy = op.getType().template cast<ShapedType>();
-    Type outElementTy = inputTy.getElementType();
+    Type resultETy = inputTy.getElementType();
 
     if (!inputTy.hasStaticShape())
       return failure();
 
     // Determine what the initial value needs to be for the max pool op.
     Attribute initialAttr;
-    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
+    if (resultETy.isF32())
       initialAttr = rewriter.getFloatAttr(
-          outElementTy,
-          APFloat::getLargest(
-              outElementTy.cast<FloatType>().getFloatSemantics(), true));
+          resultETy,
+          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
+                              true));
 
-    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
+    if (resultETy.isa<IntegerType>())
       initialAttr = rewriter.getIntegerAttr(
-          outElementTy,
-          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
-
-    if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
-      initialAttr = rewriter.getZeroAttr(outElementTy);
+          resultETy,
+          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
 
     if (!initialAttr)
       return rewriter.notifyMatchFailure(
```
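The hunk above also pins down the seed value for the pooling init tensor: linalg.pooling_nhwc_max folds every window element into a running maximum, so the tensor is pre-filled with the most negative representable value, making the first comparison a no-op. A small standalone sketch of the two seeds using LLVM's APFloat/APInt (illustrative, not part of the pattern):

```cpp
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cstdio>

int main() {
  // f32 seed: the largest *finite* negative float (note: not -infinity).
  llvm::APFloat fSeed = llvm::APFloat::getLargest(
      llvm::APFloat::IEEEsingle(), /*Negative=*/true);
  // iN seed: the signed minimum for the element width, e.g. i8 -> -128.
  llvm::APInt iSeed = llvm::APInt::getSignedMinValue(/*numBits=*/8);

  std::printf("f32 seed = %g, i8 seed = %lld\n",
              static_cast<double>(fSeed.convertToFloat()),
              static_cast<long long>(iSeed.getSExtValue()));
  return 0;
}
```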
```diff
@@ -2566,93 +2561,216 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
         rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
 
     Value fakeWindowDims =
-        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);
 
-    if (isa<tosa::MaxPool2dOp>(op)) {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledInitTensor, strideAttr, dilationAttr);
-      return success();
-    }
+    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+        filledInitTensor, strideAttr, dilationAttr);
+    return success();
+  }
+};
 
-    if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
-      Value poolingOp = rewriter
-                            .create<linalg::PoolingNhwcSumOp>(
-                                loc, ArrayRef<Type>{resultTy},
-                                ValueRange{paddedInput, fakeWindowDims},
-                                filledInitTensor, strideAttr, dilationAttr)
-                            .getResult(0);
-      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
-      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
-          ArrayRef<AffineMap>({affineMap}),
-          getNParallelLoopsAttrs(resultTy.getRank()),
-          [&](OpBuilder &b, Location loc, ValueRange args) {
-            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-            auto iH = rewriter.create<ConstantIndexOp>(
-                loc, poolingOpTy.getDimSize(1) - 1);
-            auto iW = rewriter.create<ConstantIndexOp>(
-                loc, poolingOpTy.getDimSize(2) - 1);
-
-            // Compute the indices from either end.
-            auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
-            auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
-            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
-            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
-
-            // Determines what the portion of valid input is covered by the
-            // kernel.
-            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
-              if (pad == 0)
-                return v;
-
-              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
-              Value dx = rewriter.create<SubIOp>(loc, x, padVal);
-
-              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
-                                                        dx, zero);
-              Value offset =
-                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
-              return rewriter.create<mlir::AddIOp>(loc, v, offset)
-                  ->getResult(0);
-            };
-
-            // Compute the vertical component of coverage.
-            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
-            auto kH1 = padFn(kH0, y0, pad[2]);
-            auto kH2 = padFn(kH1, y1, pad[3]);
-            auto kHCmp =
-                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
-            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
-
-            // compute teh horizontal component of coverage.
-            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
-            auto kW1 = padFn(kW0, x0, pad[4]);
-            auto kW2 = padFn(kW1, x1, pad[5]);
-            auto kWCmp =
-                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
-            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
-
-            // Compute the total number of elements and normalize.
-            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
-            auto countI = rewriter.create<mlir::IndexCastOp>(
-                loc, rewriter.getI32Type(), count);
+class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
+public:
+  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type inElementTy = inputTy.getElementType();
+
+    ShapedType resultTy = op.getType().template cast<ShapedType>();
+    Type resultETy = inputTy.getElementType();
+
+    Type accETy =
+        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+    ShapedType accTy = resultTy.clone(accETy);
+
+    if (!inputTy.hasStaticShape())
+      return failure();
+
+    // Apply padding as necessary.
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+    pad.resize(pad.size() + 2, 0);
+    Attribute initialAttr = rewriter.getZeroAttr(accETy);
+    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
+
+    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+    SmallVector<int64_t> kernel, stride;
+    getValuesFromIntArrayAttribute(op.kernel(), kernel);
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+
+    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+    // Create the linalg op that performs pooling.
+    Value poolInitTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
+
+    Value filledInitTensor =
+        rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
+            .result();
+
+    Value fakeWindowDims =
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);
+
+    // Sum across the pooled region.
+    Value poolingOp = rewriter
+                          .create<linalg::PoolingNhwcSumOp>(
+                              loc, ArrayRef<Type>{accTy},
+                              ValueRange{paddedInput, fakeWindowDims},
+                              filledInitTensor, strideAttr, dilationAttr)
+                          .getResult(0);
+
+    // Normalize the summed value by the number of elements grouped in each
+    // pool.
+    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+
+    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
+        ValueRange{genericInitTensor},
+        ArrayRef<AffineMap>({affineMap, affineMap}),
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+          auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+          auto iH = rewriter.create<ConstantIndexOp>(
+              loc, poolingOpTy.getDimSize(1) - 1);
+          auto iW = rewriter.create<ConstantIndexOp>(
+              loc, poolingOpTy.getDimSize(2) - 1);
+
+          // Compute the indices from either end.
+          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
+          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
+          auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+          auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+          // Determines what the portion of valid input is covered by the
+          // kernel.
+          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+            if (pad == 0)
+              return v;
+
+            auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+            Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+            Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                      dx, zero);
+            Value offset = rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
+            return rewriter.create<mlir::AddIOp>(loc, v, offset)->getResult(0);
+          };
+
+          // Compute the vertical component of coverage.
+          auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+          auto kH1 = padFn(kH0, y0, pad[2]);
+          auto kH2 = padFn(kH1, y1, pad[3]);
+          auto kHCmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+          auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+          // compute the horizontal component of coverage.
+          auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+          auto kW1 = padFn(kW0, x0, pad[4]);
+          auto kW2 = padFn(kW1, x1, pad[5]);
+          auto kWCmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+          auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+          // Compute the total number of elements and normalize.
+          Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+          auto countI = rewriter.create<mlir::IndexCastOp>(
+              loc, rewriter.getI32Type(), count);
+
+          // Divide by the number of summed values. For floats this is just
+          // a div however for quantized values input normalization had
+          // to be applied.
+          Value poolVal = args[0];
+          if (accETy.isa<FloatType>()) {
             auto countF =
                 rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+            poolVal =
+                rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
+          } else {
 
-            auto div =
-                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+            // If we have quantization information we need to apply an offset
+            // for the input zp value.
+            if (op.quantization_info()) {
+              auto quantizationInfo = op.quantization_info().getValue();
+              auto inputZp = rewriter.create<mlir::ConstantOp>(
+                  loc, quantizationInfo.input_zp());
+              Value offset =
+                  rewriter.create<mlir::MulIOp>(loc, accETy, countI, inputZp);
+              poolVal = rewriter.create<SubIOp>(loc, accETy, poolVal, offset);
+            }
 
-            rewriter.create<linalg::YieldOp>(loc, div);
-          });
+            // Compute the multiplier and shift values for the quantization
+            // normalization. Preferably we would want to compute more bits
+            // however 32-bits should be enough for compute. Honestly we
+            // should probably straight divide.
+            int64_t numerator = ((1 << 30) + 1);
+            int64_t shift = 30;
+
+            Value numeratorVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(numerator));
+            Value multiplierVal =
+                rewriter
+                    .create<UnsignedDivIOp>(loc, rewriter.getI32Type(),
+                                            numeratorVal, countI)
+                    .getResult();
+            Value shiftVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getI8IntegerAttr(shift));
+
+            auto scaled =
+                rewriter
+                    .create<tosa::ApplyScaleOp>(
+                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
+                        shiftVal, rewriter.getBoolAttr(false))
+                    .getResult();
+
+            // If we have quantization information we need to apply output
+            // zeropoint.
+            if (op.quantization_info()) {
+              auto quantizationInfo = op.quantization_info().getValue();
+              auto outputZp = rewriter.create<mlir::ConstantOp>(
+                  loc, quantizationInfo.output_zp());
+              scaled =
+                  rewriter.create<AddIOp>(loc, scaled, outputZp).getResult();
+            }
 
-          rewriter.replaceOp(op, genericOp.getResult(0));
-          return success();
-        }
+            // Apply Clip.
+            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
+
+            auto min = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(
+                         accETy,
+                         APInt::getSignedMinValue(outBitwidth).getSExtValue()));
+            auto max = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(
+                         accETy,
+                         APInt::getSignedMaxValue(outBitwidth).getSExtValue()));
+            auto clamp = clampHelper<mlir::CmpIOp>(
+                loc, scaled, min, max, CmpIPredicate::slt, rewriter);
+
+            // Convert type.
+            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
+          }
 
-        return failure();
+          // Cast to output type.
+
+          rewriter.create<linalg::YieldOp>(loc, poolVal);
+        });
+
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
   }
 };
```
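In scalar terms, the generic op's body computes, for each output element: how many kernel taps landed on valid (non-padded) input, then normalizes the pooled sum by that count. On the quantized path this becomes: subtract the input zero point's contribution, scale by a 32-bit reciprocal of the count, add the output zero point back, and clamp to the output range. A hypothetical scalar model of that pipeline (helper names invented for illustration; an i8 output is assumed):

```cpp
#include <algorithm>
#include <cstdint>

// Scalar model of the AvgPool2dConverter generic-op body for the quantized
// path. All names are illustrative; this is not the MLIR API.
struct QuantInfo { int32_t inputZp, outputZp; };

// Mirrors padFn: shrink the kernel extent by however far it hangs past a
// padded edge (x is the output position's distance from that edge).
int64_t padFn(int64_t v, int64_t x, int64_t pad) {
  if (pad == 0) return v;
  return v + std::min<int64_t>(x - pad, 0);
}

int32_t normalizeOnePixel(int32_t sum, int64_t y0, int64_t y1, int64_t x0,
                          int64_t x1, int64_t kH, int64_t kW,
                          const int64_t pad[6], QuantInfo q) {
  // Count the valid elements covered by the kernel (clamped to >= 1).
  int64_t kh = std::max<int64_t>(padFn(padFn(kH, y0, pad[2]), y1, pad[3]), 1);
  int64_t kw = std::max<int64_t>(padFn(padFn(kW, x0, pad[4]), x1, pad[5]), 1);
  int32_t count = static_cast<int32_t>(kh * kw);

  // Remove the input zero-point contribution of every summed element.
  int32_t acc = sum - count * q.inputZp;

  // Multiply by a 32-bit reciprocal and shift (the tosa.apply_scale step,
  // with its rounding simplified to round-to-nearest).
  int32_t multiplier = ((1 << 30) + 1) / count;
  int64_t scaled =
      (static_cast<int64_t>(acc) * multiplier + (1LL << 29)) >> 30;

  // Re-apply the output zero point, then clamp to the output type (i8 here).
  scaled += q.outputZp;
  return static_cast<int32_t>(
      std::min<int64_t>(std::max<int64_t>(scaled, -128), 127));
}
```

For a window fully inside the image the count is simply kH * kW; the padFn adjustments only kick in near padded edges, which is what keeps border averages from being diluted by the zero padding.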

```diff
@@ -2719,8 +2837,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       TileConverter,
       TransposeConverter,
       MatMulConverter,
-      Pool2dConverter<tosa::AvgPool2dOp>,
-      Pool2dConverter<tosa::MaxPool2dOp>,
+      MaxPool2dConverter,
+      AvgPool2dConverter,
       FullyConnectedConverter>(patterns->getContext());
   // clang-format on
 }
```
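For reference, these patterns reach a conversion pass through the populate function above. A sketch of a driver body (an excerpt assuming the usual pass boilerplate; the target setup shown is illustrative, not the actual pass definition):

```cpp
// Hypothetical excerpt from a pass's runOnFunction(): gather the
// TOSA-to-linalg patterns (now including MaxPool2dConverter and
// AvgPool2dConverter) and apply them as a conversion.
RewritePatternSet patterns(&getContext());
mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(&patterns);

ConversionTarget target(getContext());
target.addIllegalDialect<tosa::TosaDialect>();
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                       tensor::TensorDialect>();

if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
  signalPassFailure();
```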
