Skip to content

[mlir][tosa] Constant folding for reciprocal #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
RewritePatternSet &patterns);

Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===- FoldUtils.h - Helper Functions for Folds -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helper functions useful for various different TOSA constant folds.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H
#define MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H

#include <functional>
#include <mlir/IR/BuiltinAttributes.h>

namespace mlir {
namespace tosa {

/// Rounding mode to be used on floating point operations that require rounding.
static constexpr llvm::RoundingMode tosaRoundingMode =
llvm::APFloat::rmNearestTiesToEven;

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValueType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValueType values of type \p TargetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
TargetType targetType);

} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFoldCommon.cpp
TosaFoldConstantReciprocal.cpp
TosaFoldConstantTranspose.cpp
TosaInferShapes.cpp
TosaLayerwiseConstantFoldPass.cpp
Expand Down
83 changes: 83 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//===- TosaFoldCommon.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helper functions useful for various different TOSA constant folds.
//
//===----------------------------------------------------------------------===//

#include "TosaFoldCommon.h"
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Matchers.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

LogicalResult
mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter) {
auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
if (failed(floatCheck)) {
return floatCheck;
}
return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

LogicalResult
mlir::tosa::notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter) {
// Check whether the tensor is constant and dense
// TODO We currently ensure the tensor is dense by using the correct type for
// the bind_value, however we do not actually need this value. It would be
// nicer to only have a check here.
DenseElementsAttr tmp;
if (!matchPattern(toCheck, m_Constant(&tmp))) {
return rewriter.notifyMatchFailure(location,
"Non-const or non-dense input tensor");
}

// Make sure it actually is a TOSA constant (the match allows for other
// constants as well)
if (isa<ConstOp>(toCheck.getDefiningOp())) {
return success();
}

return rewriter.notifyMatchFailure(location,
"The reciprocal can only be folded if "
"it operates on a TOSA constant");
}

LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter) {
if (isa<FloatType>(toCheck.getType().getElementType())) {
return success();
}
return rewriter.notifyMatchFailure(location,
"Unexpected input tensor type: the "
"TOSA spec only allows floats");
}

bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp,
DenseElementsAttr values) {
assert(unaryOp->getNumOperands() == 1);
auto inputOp = unaryOp->getOperand(0);

// If the input is a splat, we don't care for the number of users
if (isa<SplatElementsAttr>(values)) {
return true;
}

// If this is the only use of the tensor it should be replaced as no
// additional memory is required
return inputOp.hasOneUse();
}
49 changes: 49 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===- TosaFoldCommon.h - Helper Functions for Folds ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helper functions useful for various different TOSA constant folds.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H
#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H

#include <llvm/ADT/APFloat.h>
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
#include <mlir/IR/PatternMatch.h>

namespace mlir {
namespace tosa {

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter);

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter);

/// Function that checks if the type contained in \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
PatternRewriter &rewriter);

/// Heuristic to decide when to replace a unary operation on a constant with the
/// folded value.
/// Folding operations on constants can lead to an increased memory usage
/// whenever the input cannot be replaced but a new constant is inserted. Hence,
/// this will currently only suggest folding when the memory impact is
/// negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values);

} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H
80 changes: 80 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//===- TosaFoldConstantReciprocal.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA reciprocal operation on constant data
//
//===----------------------------------------------------------------------===//

#include "TosaFoldCommon.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/FoldUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

using OpRewritePattern::OpRewritePattern;

static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
APFloat recip = recipAttr.getValue();
recip.divide(floatVal, tosaRoundingMode);

return recip;
}

LogicalResult matchAndRewrite(ReciprocalOp recip,
PatternRewriter &rewriter) const override {
auto inputTensor = recip.getInput1();

// Check that we can apply folding
auto preCondCheck =
notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
if (failed(preCondCheck)) {
return preCondCheck;
}

// Extract the tensor values
DenseElementsAttr inputValues;
matchPattern(inputTensor, m_Constant(&inputValues));

// Check whether this should be folded.
if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
return rewriter.notifyMatchFailure(
recip, "Currently, reciprocals will only be folded if the input "
"tensor has a single user");
}

// Create a new tensor with the updated values
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &computeReciprocal,
cast<FloatType>(inputValues.getElementType()));

// Replace the use of the reciprocal with the transformed tensor
rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
return success();
}
};

} // namespace

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantReciprocal>(ctx);
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
RewritePatternSet patterns(ctx);
auto func = getOperation();

mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
populateTosaOpsCanonicalizationPatterns(ctx, patterns);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_library(MLIRDialectUtils
FoldUtils.cpp
IndexingUtils.cpp
ReshapeOpsUtils.cpp
StructuredOpsUtils.cpp
Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/Dialect/Utils/FoldUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===- FoldUtils.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helper functions useful for various different TOSA constant folds.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/FoldUtils.h"

#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/SmallVector.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>

using namespace mlir;
using namespace mlir::tosa;

template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr mlir::tosa::applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
TargetType targetType) {
SmallVector<TargetValType> transformedValues;
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
transformedValues.push_back(transformedVal);
}

// Make sure that the output tensor has the expected output type
auto inShape = toTransform.getType();
auto outTy = inShape.cloneWith({}, targetType);

return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr
mlir::tosa::applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
FloatType targetType);
Loading