Skip to content

[mlir][affine] remove divide zero check when simplifer affineMap (#64622) #68519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 134 additions & 68 deletions mlir/include/mlir/IR/AffineExprVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_IR_AFFINEEXPRVISITOR_H

#include "mlir/IR/AffineExpr.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {
Expand Down Expand Up @@ -65,8 +66,78 @@ namespace mlir {
/// just as efficient as having your own switch instruction over the instruction
/// opcode.

template <typename SubClass, typename RetTy>
class AffineExprVisitorBase {
public:
// Function to visit an AffineExpr.
RetTy visit(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
auto self = static_cast<SubClass *>(this);
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
case AffineExprKind::DimId:
return self->visitDimExpr(cast<AffineDimExpr>(expr));
case AffineExprKind::SymbolId:
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
}
llvm_unreachable("Unknown AffineExpr");
}

//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular instruction type.
// The default behavior is to generalize the instruction type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
//

// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
RetTy visitAddExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitMulExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitModExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
};

template <typename SubClass, typename RetTy = void>
class AffineExprVisitor {
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
Expand All @@ -75,117 +146,112 @@ class AffineExprVisitor {
RetTy walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
auto self = static_cast<SubClass *>(this);
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
cast<AffineConstantExpr>(expr));
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
cast<AffineDimExpr>(expr));
return self->visitDimExpr(cast<AffineDimExpr>(expr));
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
cast<AffineSymbolExpr>(expr));
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
}
llvm_unreachable("Unknown AffineExpr");
}

// Function to visit an AffineExpr.
RetTy visit(AffineExpr expr) {
private:
// Walk the operands - each operand is itself walked in post order.
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
walkPostOrder(expr.getLHS());
walkPostOrder(expr.getRHS());
}
};

template <typename SubClass>
class AffineExprVisitor<SubClass, LogicalResult>
: public AffineExprVisitorBase<SubClass, LogicalResult> {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
public:
// Function to walk an AffineExpr (in post order).
LogicalResult walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
auto self = static_cast<SubClass *>(this);
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
cast<AffineConstantExpr>(expr));
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
cast<AffineDimExpr>(expr));
return self->visitDimExpr(cast<AffineDimExpr>(expr));
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
cast<AffineSymbolExpr>(expr));
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
}
llvm_unreachable("Unknown AffineExpr");
}

//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular instruction type.
// The default behavior is to generalize the instruction type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
//

// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
RetTy visitAddExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitMulExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitModExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }

private:
// Walk the operands - each operand is itself walked in post order.
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
walkPostOrder(expr.getLHS());
walkPostOrder(expr.getRHS());
LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
if (failed(walkPostOrder(expr.getLHS())))
return failure();
if (failed(walkPostOrder(expr.getRHS())))
return failure();
return success();
}
};

Expand Down Expand Up @@ -246,7 +312,7 @@ class AffineExprVisitor {
// expressions are mapped to the same local identifier (same column position in
// 'localVarCst').
class SimpleAffineExprFlattener
: public AffineExprVisitor<SimpleAffineExprFlattener> {
: public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
public:
// Flattend expression layout: [dims, symbols, locals, constant]
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
Expand Down Expand Up @@ -275,21 +341,21 @@ class SimpleAffineExprFlattener
virtual ~SimpleAffineExprFlattener() = default;

// Visitor method overrides.
void visitMulExpr(AffineBinaryOpExpr expr);
void visitAddExpr(AffineBinaryOpExpr expr);
void visitDimExpr(AffineDimExpr expr);
void visitSymbolExpr(AffineSymbolExpr expr);
void visitConstantExpr(AffineConstantExpr expr);
void visitCeilDivExpr(AffineBinaryOpExpr expr);
void visitFloorDivExpr(AffineBinaryOpExpr expr);
LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
LogicalResult visitDimExpr(AffineDimExpr expr);
LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
LogicalResult visitConstantExpr(AffineConstantExpr expr);
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);

//
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
//
// A mod expression "expr mod c" is thus flattened by introducing a new local
// variable q (= expr floordiv c), such that expr mod c is replaced with
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
void visitModExpr(AffineBinaryOpExpr expr);
LogicalResult visitModExpr(AffineBinaryOpExpr expr);

protected:
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
Expand Down Expand Up @@ -328,7 +394,7 @@ class SimpleAffineExprFlattener
//
// A ceildiv is similarly flattened:
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);

int findLocalId(AffineExpr localExpr);

Expand Down
9 changes: 5 additions & 4 deletions mlir/include/mlir/IR/AffineMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,18 @@ class AffineMap {
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible.
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results) const;
SmallVectorImpl<Attribute> &results,
bool *hasPoison = nullptr) const;

/// Propagates the constant operands into this affine map. Operands are
/// allowed to be null, at which point they are treated as non-constant. This
/// does not change the number of symbols and dimensions. Returns a new map,
/// which may be equal to the old map if no folding happened. If `results` is
/// provided and if all expressions in the map were folded to constants,
/// `results` will contain the values of these constants.
AffineMap
partialConstantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<int64_t> *results = nullptr) const;
AffineMap partialConstantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<int64_t> *results = nullptr,
bool *hasPoison = nullptr) const;

/// Returns the AffineMap resulting from composing `this` with `map`.
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many
Expand Down
10 changes: 7 additions & 3 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
} // namespace

// Flattens the expressions in map. Returns failure if 'expr' was unable to be
// flattened (i.e., semi-affine expressions not handled yet).
// flattened. For example two specific cases:
// 1. semi-affine expressions not handled yet.
// 2. has poison expression (i.e., division by zero).
static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
unsigned numSymbols,
Expand All @@ -85,8 +87,10 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
for (auto expr : exprs) {
if (!expr.isPureAffine())
return failure();

flattener.walkPostOrder(expr);
// has poison expression
auto flattenResult = flattener.walkPostOrder(expr);
if (failed(flattenResult))
return failure();
}

assert(flattener.operandExprStack.size() == exprs.size());
Expand Down
14 changes: 12 additions & 2 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
Expand Down Expand Up @@ -226,6 +227,8 @@ void AffineDialect::initialize() {
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
return builder.create<ub::PoisonOp>(loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}

Expand Down Expand Up @@ -580,7 +583,12 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {

// Otherwise, default to folding the map.
SmallVector<Attribute, 1> result;
if (failed(map.constantFold(adaptor.getMapOperands(), result)))
bool hasPoison = false;
auto foldResult =
map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
if (hasPoison)
return ub::PoisonAttr::get(getContext());
if (failed(foldResult))
return {};
return result[0];
}
Expand Down Expand Up @@ -3379,7 +3387,9 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
return failure();

SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
flattener.walkPostOrder(resultExpr);
auto flattenResult = flattener.walkPostOrder(resultExpr);
if (failed(flattenResult))
return failure();

// Fail if the flattened expression has local variables.
if (flattener.operandExprStack.back().size() !=
Expand Down
Loading