Skip to content

Commit dc4786b

Browse files
lipracerjsetoain
andauthored
[mlir][affine] remove divide zero check when simplifer affineMap (llvm#64622) (llvm#68519)
When performing constant folding on the affineApplyOp, there is a division of 0 in the affine map. [related issue](llvm#64622) --------- Co-authored-by: Javier Setoain <[email protected]>
1 parent c093383 commit dc4786b

File tree

8 files changed

+255
-112
lines changed

8 files changed

+255
-112
lines changed

mlir/include/mlir/IR/AffineExprVisitor.h

Lines changed: 134 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_IR_AFFINEEXPRVISITOR_H
1515

1616
#include "mlir/IR/AffineExpr.h"
17+
#include "mlir/Support/LogicalResult.h"
1718
#include "llvm/ADT/ArrayRef.h"
1819

1920
namespace mlir {
@@ -65,8 +66,78 @@ namespace mlir {
6566
/// just as efficient as having your own switch instruction over the instruction
6667
/// opcode.
6768

69+
template <typename SubClass, typename RetTy>
70+
class AffineExprVisitorBase {
71+
public:
72+
// Function to visit an AffineExpr.
73+
RetTy visit(AffineExpr expr) {
74+
static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
75+
"Must instantiate with a derived type of AffineExprVisitor");
76+
auto self = static_cast<SubClass *>(this);
77+
switch (expr.getKind()) {
78+
case AffineExprKind::Add: {
79+
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
80+
return self->visitAddExpr(binOpExpr);
81+
}
82+
case AffineExprKind::Mul: {
83+
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
84+
return self->visitMulExpr(binOpExpr);
85+
}
86+
case AffineExprKind::Mod: {
87+
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
88+
return self->visitModExpr(binOpExpr);
89+
}
90+
case AffineExprKind::FloorDiv: {
91+
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
92+
return self->visitFloorDivExpr(binOpExpr);
93+
}
94+
case AffineExprKind::CeilDiv: {
95+
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
96+
return self->visitCeilDivExpr(binOpExpr);
97+
}
98+
case AffineExprKind::Constant:
99+
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
100+
case AffineExprKind::DimId:
101+
return self->visitDimExpr(cast<AffineDimExpr>(expr));
102+
case AffineExprKind::SymbolId:
103+
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
104+
}
105+
llvm_unreachable("Unknown AffineExpr");
106+
}
107+
108+
//===--------------------------------------------------------------------===//
109+
// Visitation functions... these functions provide default fallbacks in case
110+
// the user does not specify what to do for a particular instruction type.
111+
// The default behavior is to generalize the instruction type to its subtype
112+
// and try visiting the subtype. All of this should be inlined perfectly,
113+
// because there are no virtual functions to get in the way.
114+
//
115+
116+
// Default visit methods. Note that the default op-specific binary op visit
117+
// methods call the general visitAffineBinaryOpExpr visit method.
118+
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
119+
RetTy visitAddExpr(AffineBinaryOpExpr expr) {
120+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
121+
}
122+
RetTy visitMulExpr(AffineBinaryOpExpr expr) {
123+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
124+
}
125+
RetTy visitModExpr(AffineBinaryOpExpr expr) {
126+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
127+
}
128+
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
129+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
130+
}
131+
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
132+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
133+
}
134+
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
135+
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
136+
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
137+
};
138+
68139
template <typename SubClass, typename RetTy = void>
69-
class AffineExprVisitor {
140+
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
70141
//===--------------------------------------------------------------------===//
71142
// Interface code - This is the public interface of the AffineExprVisitor
72143
// that you use to visit affine expressions...
@@ -75,117 +146,112 @@ class AffineExprVisitor {
75146
RetTy walkPostOrder(AffineExpr expr) {
76147
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
77148
"Must instantiate with a derived type of AffineExprVisitor");
149+
auto self = static_cast<SubClass *>(this);
78150
switch (expr.getKind()) {
79151
case AffineExprKind::Add: {
80152
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
81153
walkOperandsPostOrder(binOpExpr);
82-
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
154+
return self->visitAddExpr(binOpExpr);
83155
}
84156
case AffineExprKind::Mul: {
85157
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
86158
walkOperandsPostOrder(binOpExpr);
87-
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
159+
return self->visitMulExpr(binOpExpr);
88160
}
89161
case AffineExprKind::Mod: {
90162
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
91163
walkOperandsPostOrder(binOpExpr);
92-
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
164+
return self->visitModExpr(binOpExpr);
93165
}
94166
case AffineExprKind::FloorDiv: {
95167
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
96168
walkOperandsPostOrder(binOpExpr);
97-
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
169+
return self->visitFloorDivExpr(binOpExpr);
98170
}
99171
case AffineExprKind::CeilDiv: {
100172
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
101173
walkOperandsPostOrder(binOpExpr);
102-
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
174+
return self->visitCeilDivExpr(binOpExpr);
103175
}
104176
case AffineExprKind::Constant:
105-
return static_cast<SubClass *>(this)->visitConstantExpr(
106-
cast<AffineConstantExpr>(expr));
177+
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
107178
case AffineExprKind::DimId:
108-
return static_cast<SubClass *>(this)->visitDimExpr(
109-
cast<AffineDimExpr>(expr));
179+
return self->visitDimExpr(cast<AffineDimExpr>(expr));
110180
case AffineExprKind::SymbolId:
111-
return static_cast<SubClass *>(this)->visitSymbolExpr(
112-
cast<AffineSymbolExpr>(expr));
181+
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
113182
}
183+
llvm_unreachable("Unknown AffineExpr");
114184
}
115185

116-
// Function to visit an AffineExpr.
117-
RetTy visit(AffineExpr expr) {
186+
private:
187+
// Walk the operands - each operand is itself walked in post order.
188+
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
189+
walkPostOrder(expr.getLHS());
190+
walkPostOrder(expr.getRHS());
191+
}
192+
};
193+
194+
template <typename SubClass>
195+
class AffineExprVisitor<SubClass, LogicalResult>
196+
: public AffineExprVisitorBase<SubClass, LogicalResult> {
197+
//===--------------------------------------------------------------------===//
198+
// Interface code - This is the public interface of the AffineExprVisitor
199+
// that you use to visit affine expressions...
200+
public:
201+
// Function to walk an AffineExpr (in post order).
202+
LogicalResult walkPostOrder(AffineExpr expr) {
118203
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
119204
"Must instantiate with a derived type of AffineExprVisitor");
205+
auto self = static_cast<SubClass *>(this);
120206
switch (expr.getKind()) {
121207
case AffineExprKind::Add: {
122208
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
123-
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
209+
if (failed(walkOperandsPostOrder(binOpExpr)))
210+
return failure();
211+
return self->visitAddExpr(binOpExpr);
124212
}
125213
case AffineExprKind::Mul: {
126214
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
127-
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
215+
if (failed(walkOperandsPostOrder(binOpExpr)))
216+
return failure();
217+
return self->visitMulExpr(binOpExpr);
128218
}
129219
case AffineExprKind::Mod: {
130220
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
131-
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
221+
if (failed(walkOperandsPostOrder(binOpExpr)))
222+
return failure();
223+
return self->visitModExpr(binOpExpr);
132224
}
133225
case AffineExprKind::FloorDiv: {
134226
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
135-
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
227+
if (failed(walkOperandsPostOrder(binOpExpr)))
228+
return failure();
229+
return self->visitFloorDivExpr(binOpExpr);
136230
}
137231
case AffineExprKind::CeilDiv: {
138232
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
139-
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
233+
if (failed(walkOperandsPostOrder(binOpExpr)))
234+
return failure();
235+
return self->visitCeilDivExpr(binOpExpr);
140236
}
141237
case AffineExprKind::Constant:
142-
return static_cast<SubClass *>(this)->visitConstantExpr(
143-
cast<AffineConstantExpr>(expr));
238+
return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
144239
case AffineExprKind::DimId:
145-
return static_cast<SubClass *>(this)->visitDimExpr(
146-
cast<AffineDimExpr>(expr));
240+
return self->visitDimExpr(cast<AffineDimExpr>(expr));
147241
case AffineExprKind::SymbolId:
148-
return static_cast<SubClass *>(this)->visitSymbolExpr(
149-
cast<AffineSymbolExpr>(expr));
242+
return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
150243
}
151244
llvm_unreachable("Unknown AffineExpr");
152245
}
153246

154-
//===--------------------------------------------------------------------===//
155-
// Visitation functions... these functions provide default fallbacks in case
156-
// the user does not specify what to do for a particular instruction type.
157-
// The default behavior is to generalize the instruction type to its subtype
158-
// and try visiting the subtype. All of this should be inlined perfectly,
159-
// because there are no virtual functions to get in the way.
160-
//
161-
162-
// Default visit methods. Note that the default op-specific binary op visit
163-
// methods call the general visitAffineBinaryOpExpr visit method.
164-
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
165-
RetTy visitAddExpr(AffineBinaryOpExpr expr) {
166-
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
167-
}
168-
RetTy visitMulExpr(AffineBinaryOpExpr expr) {
169-
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
170-
}
171-
RetTy visitModExpr(AffineBinaryOpExpr expr) {
172-
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
173-
}
174-
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
175-
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
176-
}
177-
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
178-
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
179-
}
180-
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
181-
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
182-
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
183-
184247
private:
185248
// Walk the operands - each operand is itself walked in post order.
186-
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
187-
walkPostOrder(expr.getLHS());
188-
walkPostOrder(expr.getRHS());
249+
LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
250+
if (failed(walkPostOrder(expr.getLHS())))
251+
return failure();
252+
if (failed(walkPostOrder(expr.getRHS())))
253+
return failure();
254+
return success();
189255
}
190256
};
191257

@@ -246,7 +312,7 @@ class AffineExprVisitor {
246312
// expressions are mapped to the same local identifier (same column position in
247313
// 'localVarCst').
248314
class SimpleAffineExprFlattener
249-
: public AffineExprVisitor<SimpleAffineExprFlattener> {
315+
: public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
250316
public:
251317
// Flattend expression layout: [dims, symbols, locals, constant]
252318
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
@@ -275,21 +341,21 @@ class SimpleAffineExprFlattener
275341
virtual ~SimpleAffineExprFlattener() = default;
276342

277343
// Visitor method overrides.
278-
void visitMulExpr(AffineBinaryOpExpr expr);
279-
void visitAddExpr(AffineBinaryOpExpr expr);
280-
void visitDimExpr(AffineDimExpr expr);
281-
void visitSymbolExpr(AffineSymbolExpr expr);
282-
void visitConstantExpr(AffineConstantExpr expr);
283-
void visitCeilDivExpr(AffineBinaryOpExpr expr);
284-
void visitFloorDivExpr(AffineBinaryOpExpr expr);
344+
LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
345+
LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
346+
LogicalResult visitDimExpr(AffineDimExpr expr);
347+
LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
348+
LogicalResult visitConstantExpr(AffineConstantExpr expr);
349+
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
350+
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
285351

286352
//
287353
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
288354
//
289355
// A mod expression "expr mod c" is thus flattened by introducing a new local
290356
// variable q (= expr floordiv c), such that expr mod c is replaced with
291357
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
292-
void visitModExpr(AffineBinaryOpExpr expr);
358+
LogicalResult visitModExpr(AffineBinaryOpExpr expr);
293359

294360
protected:
295361
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -328,7 +394,7 @@ class SimpleAffineExprFlattener
328394
//
329395
// A ceildiv is similarly flattened:
330396
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
331-
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
397+
LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
332398

333399
int findLocalId(AffineExpr localExpr);
334400

mlir/include/mlir/IR/AffineMap.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,17 +310,18 @@ class AffineMap {
310310
/// Folds the results of the application of an affine map on the provided
311311
/// operands to a constant if possible.
312312
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
313-
SmallVectorImpl<Attribute> &results) const;
313+
SmallVectorImpl<Attribute> &results,
314+
bool *hasPoison = nullptr) const;
314315

315316
/// Propagates the constant operands into this affine map. Operands are
316317
/// allowed to be null, at which point they are treated as non-constant. This
317318
/// does not change the number of symbols and dimensions. Returns a new map,
318319
/// which may be equal to the old map if no folding happened. If `results` is
319320
/// provided and if all expressions in the map were folded to constants,
320321
/// `results` will contain the values of these constants.
321-
AffineMap
322-
partialConstantFold(ArrayRef<Attribute> operandConstants,
323-
SmallVectorImpl<int64_t> *results = nullptr) const;
322+
AffineMap partialConstantFold(ArrayRef<Attribute> operandConstants,
323+
SmallVectorImpl<int64_t> *results = nullptr,
324+
bool *hasPoison = nullptr) const;
324325

325326
/// Returns the AffineMap resulting from composing `this` with `map`.
326327
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
6767
} // namespace
6868

6969
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
70-
// flattened (i.e., semi-affine expressions not handled yet).
70+
// flattened. For example two specific cases:
71+
// 1. semi-affine expressions not handled yet.
72+
// 2. has poison expression (i.e., division by zero).
7173
static LogicalResult
7274
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
7375
unsigned numSymbols,
@@ -85,8 +87,10 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
8587
for (auto expr : exprs) {
8688
if (!expr.isPureAffine())
8789
return failure();
88-
89-
flattener.walkPostOrder(expr);
90+
// has poison expression
91+
auto flattenResult = flattener.walkPostOrder(expr);
92+
if (failed(flattenResult))
93+
return failure();
9094
}
9195

9296
assert(flattener.operandExprStack.size() == exprs.size());

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
1111
#include "mlir/Dialect/MemRef/IR/MemRef.h"
12+
#include "mlir/Dialect/UB/IR/UBOps.h"
1213
#include "mlir/IR/AffineExprVisitor.h"
1314
#include "mlir/IR/IRMapping.h"
1415
#include "mlir/IR/IntegerSet.h"
@@ -226,6 +227,8 @@ void AffineDialect::initialize() {
226227
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
227228
Attribute value, Type type,
228229
Location loc) {
230+
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
231+
return builder.create<ub::PoisonOp>(loc, type, poison);
229232
return arith::ConstantOp::materialize(builder, value, type, loc);
230233
}
231234

@@ -580,7 +583,12 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
580583

581584
// Otherwise, default to folding the map.
582585
SmallVector<Attribute, 1> result;
583-
if (failed(map.constantFold(adaptor.getMapOperands(), result)))
586+
bool hasPoison = false;
587+
auto foldResult =
588+
map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
589+
if (hasPoison)
590+
return ub::PoisonAttr::get(getContext());
591+
if (failed(foldResult))
584592
return {};
585593
return result[0];
586594
}
@@ -3379,7 +3387,9 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
33793387
return failure();
33803388

33813389
SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3382-
flattener.walkPostOrder(resultExpr);
3390+
auto flattenResult = flattener.walkPostOrder(resultExpr);
3391+
if (failed(flattenResult))
3392+
return failure();
33833393

33843394
// Fail if the flattened expression has local variables.
33853395
if (flattener.operandExprStack.back().size() !=

0 commit comments

Comments
 (0)