14
14
#define MLIR_IR_AFFINEEXPRVISITOR_H
15
15
16
16
#include " mlir/IR/AffineExpr.h"
17
+ #include " mlir/Support/LogicalResult.h"
17
18
#include " llvm/ADT/ArrayRef.h"
18
19
19
20
namespace mlir {
@@ -65,8 +66,78 @@ namespace mlir {
65
66
// / just as efficient as having your own switch instruction over the instruction
66
67
// / opcode.
67
68
69
+ template <typename SubClass, typename RetTy>
70
+ class AffineExprVisitorBase {
71
+ public:
72
+ // Function to visit an AffineExpr.
73
+ RetTy visit (AffineExpr expr) {
74
+ static_assert (std::is_base_of<AffineExprVisitorBase, SubClass>::value,
75
+ " Must instantiate with a derived type of AffineExprVisitor" );
76
+ auto self = static_cast <SubClass *>(this );
77
+ switch (expr.getKind ()) {
78
+ case AffineExprKind::Add: {
79
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
80
+ return self->visitAddExpr (binOpExpr);
81
+ }
82
+ case AffineExprKind::Mul: {
83
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
84
+ return self->visitMulExpr (binOpExpr);
85
+ }
86
+ case AffineExprKind::Mod: {
87
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
88
+ return self->visitModExpr (binOpExpr);
89
+ }
90
+ case AffineExprKind::FloorDiv: {
91
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
92
+ return self->visitFloorDivExpr (binOpExpr);
93
+ }
94
+ case AffineExprKind::CeilDiv: {
95
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
96
+ return self->visitCeilDivExpr (binOpExpr);
97
+ }
98
+ case AffineExprKind::Constant:
99
+ return self->visitConstantExpr (cast<AffineConstantExpr>(expr));
100
+ case AffineExprKind::DimId:
101
+ return self->visitDimExpr (cast<AffineDimExpr>(expr));
102
+ case AffineExprKind::SymbolId:
103
+ return self->visitSymbolExpr (cast<AffineSymbolExpr>(expr));
104
+ }
105
+ llvm_unreachable (" Unknown AffineExpr" );
106
+ }
107
+
108
+ // ===--------------------------------------------------------------------===//
109
+ // Visitation functions... these functions provide default fallbacks in case
110
+ // the user does not specify what to do for a particular instruction type.
111
+ // The default behavior is to generalize the instruction type to its subtype
112
+ // and try visiting the subtype. All of this should be inlined perfectly,
113
+ // because there are no virtual functions to get in the way.
114
+ //
115
+
116
+ // Default visit methods. Note that the default op-specific binary op visit
117
+ // methods call the general visitAffineBinaryOpExpr visit method.
118
+ RetTy visitAffineBinaryOpExpr (AffineBinaryOpExpr expr) { return RetTy (); }
119
+ RetTy visitAddExpr (AffineBinaryOpExpr expr) {
120
+ return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
121
+ }
122
+ RetTy visitMulExpr (AffineBinaryOpExpr expr) {
123
+ return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
124
+ }
125
+ RetTy visitModExpr (AffineBinaryOpExpr expr) {
126
+ return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
127
+ }
128
+ RetTy visitFloorDivExpr (AffineBinaryOpExpr expr) {
129
+ return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
130
+ }
131
+ RetTy visitCeilDivExpr (AffineBinaryOpExpr expr) {
132
+ return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
133
+ }
134
+ RetTy visitConstantExpr (AffineConstantExpr expr) { return RetTy (); }
135
+ RetTy visitDimExpr (AffineDimExpr expr) { return RetTy (); }
136
+ RetTy visitSymbolExpr (AffineSymbolExpr expr) { return RetTy (); }
137
+ };
138
+
68
139
template <typename SubClass, typename RetTy = void >
69
- class AffineExprVisitor {
140
+ class AffineExprVisitor : public AffineExprVisitorBase <SubClass, RetTy> {
70
141
// ===--------------------------------------------------------------------===//
71
142
// Interface code - This is the public interface of the AffineExprVisitor
72
143
// that you use to visit affine expressions...
@@ -75,117 +146,112 @@ class AffineExprVisitor {
75
146
RetTy walkPostOrder (AffineExpr expr) {
76
147
static_assert (std::is_base_of<AffineExprVisitor, SubClass>::value,
77
148
" Must instantiate with a derived type of AffineExprVisitor" );
149
+ auto self = static_cast <SubClass *>(this );
78
150
switch (expr.getKind ()) {
79
151
case AffineExprKind::Add: {
80
152
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
81
153
walkOperandsPostOrder (binOpExpr);
82
- return static_cast <SubClass *>( this ) ->visitAddExpr (binOpExpr);
154
+ return self ->visitAddExpr (binOpExpr);
83
155
}
84
156
case AffineExprKind::Mul: {
85
157
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
86
158
walkOperandsPostOrder (binOpExpr);
87
- return static_cast <SubClass *>( this ) ->visitMulExpr (binOpExpr);
159
+ return self ->visitMulExpr (binOpExpr);
88
160
}
89
161
case AffineExprKind::Mod: {
90
162
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
91
163
walkOperandsPostOrder (binOpExpr);
92
- return static_cast <SubClass *>( this ) ->visitModExpr (binOpExpr);
164
+ return self ->visitModExpr (binOpExpr);
93
165
}
94
166
case AffineExprKind::FloorDiv: {
95
167
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
96
168
walkOperandsPostOrder (binOpExpr);
97
- return static_cast <SubClass *>( this ) ->visitFloorDivExpr (binOpExpr);
169
+ return self ->visitFloorDivExpr (binOpExpr);
98
170
}
99
171
case AffineExprKind::CeilDiv: {
100
172
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
101
173
walkOperandsPostOrder (binOpExpr);
102
- return static_cast <SubClass *>( this ) ->visitCeilDivExpr (binOpExpr);
174
+ return self ->visitCeilDivExpr (binOpExpr);
103
175
}
104
176
case AffineExprKind::Constant:
105
- return static_cast <SubClass *>(this )->visitConstantExpr (
106
- cast<AffineConstantExpr>(expr));
177
+ return self->visitConstantExpr (cast<AffineConstantExpr>(expr));
107
178
case AffineExprKind::DimId:
108
- return static_cast <SubClass *>(this )->visitDimExpr (
109
- cast<AffineDimExpr>(expr));
179
+ return self->visitDimExpr (cast<AffineDimExpr>(expr));
110
180
case AffineExprKind::SymbolId:
111
- return static_cast <SubClass *>(this )->visitSymbolExpr (
112
- cast<AffineSymbolExpr>(expr));
181
+ return self->visitSymbolExpr (cast<AffineSymbolExpr>(expr));
113
182
}
183
+ llvm_unreachable (" Unknown AffineExpr" );
114
184
}
115
185
116
- // Function to visit an AffineExpr.
117
- RetTy visit (AffineExpr expr) {
186
+ private:
187
+ // Walk the operands - each operand is itself walked in post order.
188
+ RetTy walkOperandsPostOrder (AffineBinaryOpExpr expr) {
189
+ walkPostOrder (expr.getLHS ());
190
+ walkPostOrder (expr.getRHS ());
191
+ }
192
+ };
193
+
194
+ template <typename SubClass>
195
+ class AffineExprVisitor <SubClass, LogicalResult>
196
+ : public AffineExprVisitorBase<SubClass, LogicalResult> {
197
+ // ===--------------------------------------------------------------------===//
198
+ // Interface code - This is the public interface of the AffineExprVisitor
199
+ // that you use to visit affine expressions...
200
+ public:
201
+ // Function to walk an AffineExpr (in post order).
202
+ LogicalResult walkPostOrder (AffineExpr expr) {
118
203
static_assert (std::is_base_of<AffineExprVisitor, SubClass>::value,
119
204
" Must instantiate with a derived type of AffineExprVisitor" );
205
+ auto self = static_cast <SubClass *>(this );
120
206
switch (expr.getKind ()) {
121
207
case AffineExprKind::Add: {
122
208
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
123
- return static_cast <SubClass *>(this )->visitAddExpr (binOpExpr);
209
+ if (failed (walkOperandsPostOrder (binOpExpr)))
210
+ return failure ();
211
+ return self->visitAddExpr (binOpExpr);
124
212
}
125
213
case AffineExprKind::Mul: {
126
214
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
127
- return static_cast <SubClass *>(this )->visitMulExpr (binOpExpr);
215
+ if (failed (walkOperandsPostOrder (binOpExpr)))
216
+ return failure ();
217
+ return self->visitMulExpr (binOpExpr);
128
218
}
129
219
case AffineExprKind::Mod: {
130
220
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
131
- return static_cast <SubClass *>(this )->visitModExpr (binOpExpr);
221
+ if (failed (walkOperandsPostOrder (binOpExpr)))
222
+ return failure ();
223
+ return self->visitModExpr (binOpExpr);
132
224
}
133
225
case AffineExprKind::FloorDiv: {
134
226
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
135
- return static_cast <SubClass *>(this )->visitFloorDivExpr (binOpExpr);
227
+ if (failed (walkOperandsPostOrder (binOpExpr)))
228
+ return failure ();
229
+ return self->visitFloorDivExpr (binOpExpr);
136
230
}
137
231
case AffineExprKind::CeilDiv: {
138
232
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
139
- return static_cast <SubClass *>(this )->visitCeilDivExpr (binOpExpr);
233
+ if (failed (walkOperandsPostOrder (binOpExpr)))
234
+ return failure ();
235
+ return self->visitCeilDivExpr (binOpExpr);
140
236
}
141
237
case AffineExprKind::Constant:
142
- return static_cast <SubClass *>(this )->visitConstantExpr (
143
- cast<AffineConstantExpr>(expr));
238
+ return self->visitConstantExpr (cast<AffineConstantExpr>(expr));
144
239
case AffineExprKind::DimId:
145
- return static_cast <SubClass *>(this )->visitDimExpr (
146
- cast<AffineDimExpr>(expr));
240
+ return self->visitDimExpr (cast<AffineDimExpr>(expr));
147
241
case AffineExprKind::SymbolId:
148
- return static_cast <SubClass *>(this )->visitSymbolExpr (
149
- cast<AffineSymbolExpr>(expr));
242
+ return self->visitSymbolExpr (cast<AffineSymbolExpr>(expr));
150
243
}
151
244
llvm_unreachable (" Unknown AffineExpr" );
152
245
}
153
246
154
- // ===--------------------------------------------------------------------===//
155
- // Visitation functions... these functions provide default fallbacks in case
156
- // the user does not specify what to do for a particular instruction type.
157
- // The default behavior is to generalize the instruction type to its subtype
158
- // and try visiting the subtype. All of this should be inlined perfectly,
159
- // because there are no virtual functions to get in the way.
160
- //
161
-
162
- // Default visit methods. Note that the default op-specific binary op visit
163
- // methods call the general visitAffineBinaryOpExpr visit method.
164
- RetTy visitAffineBinaryOpExpr (AffineBinaryOpExpr expr) { return RetTy (); }
165
- RetTy visitAddExpr (AffineBinaryOpExpr expr) {
166
- return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
167
- }
168
- RetTy visitMulExpr (AffineBinaryOpExpr expr) {
169
- return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
170
- }
171
- RetTy visitModExpr (AffineBinaryOpExpr expr) {
172
- return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
173
- }
174
- RetTy visitFloorDivExpr (AffineBinaryOpExpr expr) {
175
- return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
176
- }
177
- RetTy visitCeilDivExpr (AffineBinaryOpExpr expr) {
178
- return static_cast <SubClass *>(this )->visitAffineBinaryOpExpr (expr);
179
- }
180
- RetTy visitConstantExpr (AffineConstantExpr expr) { return RetTy (); }
181
- RetTy visitDimExpr (AffineDimExpr expr) { return RetTy (); }
182
- RetTy visitSymbolExpr (AffineSymbolExpr expr) { return RetTy (); }
183
-
184
247
private:
185
248
// Walk the operands - each operand is itself walked in post order.
186
- RetTy walkOperandsPostOrder (AffineBinaryOpExpr expr) {
187
- walkPostOrder (expr.getLHS ());
188
- walkPostOrder (expr.getRHS ());
249
+ LogicalResult walkOperandsPostOrder (AffineBinaryOpExpr expr) {
250
+ if (failed (walkPostOrder (expr.getLHS ())))
251
+ return failure ();
252
+ if (failed (walkPostOrder (expr.getRHS ())))
253
+ return failure ();
254
+ return success ();
189
255
}
190
256
};
191
257
@@ -246,7 +312,7 @@ class AffineExprVisitor {
246
312
// expressions are mapped to the same local identifier (same column position in
247
313
// 'localVarCst').
248
314
class SimpleAffineExprFlattener
249
- : public AffineExprVisitor<SimpleAffineExprFlattener> {
315
+ : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult > {
250
316
public:
251
317
// Flattend expression layout: [dims, symbols, locals, constant]
252
318
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
@@ -275,21 +341,21 @@ class SimpleAffineExprFlattener
275
341
virtual ~SimpleAffineExprFlattener () = default ;
276
342
277
343
// Visitor method overrides.
278
- void visitMulExpr (AffineBinaryOpExpr expr);
279
- void visitAddExpr (AffineBinaryOpExpr expr);
280
- void visitDimExpr (AffineDimExpr expr);
281
- void visitSymbolExpr (AffineSymbolExpr expr);
282
- void visitConstantExpr (AffineConstantExpr expr);
283
- void visitCeilDivExpr (AffineBinaryOpExpr expr);
284
- void visitFloorDivExpr (AffineBinaryOpExpr expr);
344
+ LogicalResult visitMulExpr (AffineBinaryOpExpr expr);
345
+ LogicalResult visitAddExpr (AffineBinaryOpExpr expr);
346
+ LogicalResult visitDimExpr (AffineDimExpr expr);
347
+ LogicalResult visitSymbolExpr (AffineSymbolExpr expr);
348
+ LogicalResult visitConstantExpr (AffineConstantExpr expr);
349
+ LogicalResult visitCeilDivExpr (AffineBinaryOpExpr expr);
350
+ LogicalResult visitFloorDivExpr (AffineBinaryOpExpr expr);
285
351
286
352
//
287
353
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
288
354
//
289
355
// A mod expression "expr mod c" is thus flattened by introducing a new local
290
356
// variable q (= expr floordiv c), such that expr mod c is replaced with
291
357
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
292
- void visitModExpr (AffineBinaryOpExpr expr);
358
+ LogicalResult visitModExpr (AffineBinaryOpExpr expr);
293
359
294
360
protected:
295
361
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -328,7 +394,7 @@ class SimpleAffineExprFlattener
328
394
//
329
395
// A ceildiv is similarly flattened:
330
396
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
331
- void visitDivExpr (AffineBinaryOpExpr expr, bool isCeil);
397
+ LogicalResult visitDivExpr (AffineBinaryOpExpr expr, bool isCeil);
332
398
333
399
int findLocalId (AffineExpr localExpr);
334
400
0 commit comments