|
| 1 | +//===- TosaFoldConstantClamp.cpp ------------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Fold TOSA Clamp operation on constant data |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 14 | +#include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| 15 | +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" |
| 16 | +#include "mlir/IR/Matchers.h" |
| 17 | +#include "mlir/Pass/Pass.h" |
| 18 | +#include <llvm/ADT/APFloat.h> |
| 19 | +#include <llvm/ADT/APInt.h> |
| 20 | +#include <mlir/IR/BuiltinAttributes.h> |
| 21 | +#include <mlir/IR/BuiltinTypes.h> |
| 22 | +#include <mlir/Support/LogicalResult.h> |
| 23 | + |
| 24 | +using namespace mlir; |
| 25 | +using namespace mlir::tosa; |
| 26 | + |
| 27 | +namespace { |
| 28 | + |
| 29 | +struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> { |
| 30 | + |
| 31 | + using OpRewritePattern::OpRewritePattern; |
| 32 | + |
| 33 | + static void |
| 34 | + changeSemanticsLossless(APFloat &floatVal, |
| 35 | + const llvm::fltSemantics *floatSemantics) { |
| 36 | + bool losesInfo; |
| 37 | + floatVal.convert(*floatSemantics, tosaRoundingMode, &losesInfo); |
| 38 | + assert(!losesInfo); |
| 39 | + } |
| 40 | + |
| 41 | + DenseElementsAttr applyClamp(DenseElementsAttr inputValues, |
| 42 | + const APInt &lowerBound, const APInt &upperBound, |
| 43 | + TensorType resultType) const { |
| 44 | + |
| 45 | + // Determine the width for the APInt comparison |
| 46 | + auto comparisonWidth = |
| 47 | + std::max(inputValues.getElementType().getIntOrFloatBitWidth(), |
| 48 | + lowerBound.getBitWidth()); |
| 49 | + // Sign-extend the upper and lower bound |
| 50 | + auto extUpperBound = upperBound.sext(comparisonWidth); |
| 51 | + auto extLowerBound = lowerBound.sext(comparisonWidth); |
| 52 | + |
| 53 | + // Determine the result type |
| 54 | + auto resultingIntType = cast<IntegerType>(resultType.getElementType()); |
| 55 | + |
| 56 | + // Lambda to perform the clamp |
| 57 | + auto clampFun = [&extLowerBound, &extUpperBound, |
| 58 | + &comparisonWidth](const APInt &val, IntegerType type) { |
| 59 | + auto clampedUpper = |
| 60 | + llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound); |
| 61 | + auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound); |
| 62 | + assert(type.getWidth() >= fullyClamped.getSignificantBits()); |
| 63 | + return fullyClamped.trunc(type.getWidth()); |
| 64 | + }; |
| 65 | + auto newTensor = applyElementWise<APInt, APInt, IntegerType>( |
| 66 | + inputValues, clampFun, resultingIntType); |
| 67 | + |
| 68 | + return newTensor; |
| 69 | + } |
| 70 | + |
| 71 | + DenseElementsAttr applyClamp(DenseElementsAttr inputValues, |
| 72 | + APFloat lowerBound, APFloat upperBound, |
| 73 | + TensorType resultType) const { |
| 74 | + auto inputValType = cast<FloatType>(inputValues.getElementType()); |
| 75 | + auto inputWidth = inputValType.getWidth(); |
| 76 | + auto bWidth = APFloat::semanticsSizeInBits(lowerBound.getSemantics()); |
| 77 | + auto *comparisonSem = inputWidth < bWidth |
| 78 | + ? &lowerBound.getSemantics() |
| 79 | + : &inputValType.getFloatSemantics(); |
| 80 | + |
| 81 | + changeSemanticsLossless(lowerBound, comparisonSem); |
| 82 | + changeSemanticsLossless(upperBound, comparisonSem); |
| 83 | + |
| 84 | + auto resultingFloatType = cast<FloatType>(resultType.getElementType()); |
| 85 | + |
| 86 | + // Ensure that the value is larger than the lower bound and smaller than the |
| 87 | + // upper bound |
| 88 | + auto clampFun = [&lowerBound, &upperBound, &comparisonSem](APFloat val, |
| 89 | + FloatType type) { |
| 90 | + if (val.isNaN()) { |
| 91 | + return APFloat::getNaN(type.getFloatSemantics()); |
| 92 | + } |
| 93 | + changeSemanticsLossless(val, comparisonSem); |
| 94 | + auto clampedUpper = val < upperBound ? val : upperBound; |
| 95 | + auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper; |
| 96 | + changeSemanticsLossless(fullyClamped, &type.getFloatSemantics()); |
| 97 | + return fullyClamped; |
| 98 | + }; |
| 99 | + auto newTensor = applyElementWise<APFloat, APFloat, FloatType>( |
| 100 | + inputValues, clampFun, resultingFloatType); |
| 101 | + |
| 102 | + return newTensor; |
| 103 | + } |
| 104 | + |
| 105 | + LogicalResult matchAndRewrite(ClampOp clampOp, |
| 106 | + PatternRewriter &rewriter) const override { |
| 107 | + auto valsToClamp = clampOp.getInput(); |
| 108 | + auto inputElementType = valsToClamp.getType().getElementType(); |
| 109 | + |
| 110 | + // Check if the input is constant |
| 111 | + if (failed(notifyIfNoTosaDenseConstantTensor(valsToClamp, clampOp, |
| 112 | + rewriter))) { |
| 113 | + return failure(); |
| 114 | + } |
| 115 | + |
| 116 | + if (isa<IntegerType>(inputElementType) && |
| 117 | + cast<IntegerType>(inputElementType).isUnsigned()) { |
| 118 | + return rewriter.notifyMatchFailure( |
| 119 | + clampOp, "Currently, unsigned integer clamps are unsupported."); |
| 120 | + } |
| 121 | + |
| 122 | + // Extract the tensor values |
| 123 | + DenseElementsAttr inputValues; |
| 124 | + matchPattern(valsToClamp, m_Constant(&inputValues)); |
| 125 | + |
| 126 | + if (!constantUnaryOpShouldBeFolded(clampOp, inputValues)) { |
| 127 | + return rewriter.notifyMatchFailure( |
| 128 | + clampOp, |
| 129 | + "Currently, clamps will only be folded if this requires only " |
| 130 | + "little additional memory usage."); |
| 131 | + } |
| 132 | + |
| 133 | + // Apply the clamp to all values of the int/float tensor |
| 134 | + auto resultType = clampOp.getType(); |
| 135 | + DenseElementsAttr newTensor; |
| 136 | + if (isa<IntegerType>(inputElementType)) { |
| 137 | + auto lowerBoundVal = clampOp.getMinIntAttr().getValue(); |
| 138 | + auto upperBoundVal = clampOp.getMaxIntAttr().getValue(); |
| 139 | + assert(lowerBoundVal.getBitWidth() == upperBoundVal.getBitWidth()); |
| 140 | + |
| 141 | + newTensor = |
| 142 | + applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType); |
| 143 | + } else { |
| 144 | + assert(isa<FloatType>(inputElementType)); |
| 145 | + auto lowerBoundVal = clampOp.getMinFp(); |
| 146 | + auto upperBoundVal = clampOp.getMaxFp(); |
| 147 | + assert(APFloat::getSizeInBits(lowerBoundVal.getSemantics()) == |
| 148 | + APFloat::getSizeInBits(upperBoundVal.getSemantics())); |
| 149 | + |
| 150 | + newTensor = |
| 151 | + applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType); |
| 152 | + } |
| 153 | + |
| 154 | + rewriter.replaceOpWithNewOp<ConstOp>(clampOp, newTensor.getType(), |
| 155 | + newTensor); |
| 156 | + |
| 157 | + return success(); |
| 158 | + } |
| 159 | +}; |
| 160 | + |
| 161 | +} // namespace |
| 162 | + |
| 163 | +void mlir::tosa::populateTosaFoldConstantClampPatterns( |
| 164 | + MLIRContext *ctx, RewritePatternSet &patterns) { |
| 165 | + patterns.add<TosaFoldConstantClamp>(ctx); |
| 166 | +} |
0 commit comments