15
15
#include " mlir/IR/BuiltinTypes.h"
16
16
#include " mlir/IR/MLIRContext.h"
17
17
#include " mlir/IR/Matchers.h"
18
+ #include " mlir/IR/Operation.h"
18
19
#include " mlir/IR/PatternMatch.h"
19
20
#include " mlir/IR/TypeUtilities.h"
20
21
#include " mlir/Support/LogicalResult.h"
21
22
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23
#include " llvm/ADT/STLExtras.h"
23
24
#include " llvm/ADT/SmallVector.h"
24
- #include " llvm/ADT/TypeSwitch.h"
25
25
#include < cassert>
26
26
#include < cstdint>
27
27
@@ -100,11 +100,63 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {
100
100
101
101
enum class ExtensionKind { Sign, Zero };
102
102
103
- ExtensionKind getExtensionKind (Operation *op) {
104
- assert (op);
105
- assert ((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && " Not an extension op" );
106
- return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
107
- }
103
+ // / Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
104
+ // / the exact op type. Exposes helper functions to query the types, operands,
105
+ // / and the result. This is so that we can handle both extension kinds without
106
+ // / needing to use templates or branching.
107
+ class ExtensionOp {
108
+ public:
109
+ // / Attemps to create a new extension op from `op`. Returns an extension op
110
+ // / wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
111
+ // / otherwise.
112
+ static FailureOr<ExtensionOp> from (Operation *op) {
113
+ if (auto sext = dyn_cast_or_null<arith::ExtSIOp>(op))
114
+ return ExtensionOp{op, ExtensionKind::Sign};
115
+ if (auto zext = dyn_cast_or_null<arith::ExtUIOp>(op))
116
+ return ExtensionOp{op, ExtensionKind::Zero};
117
+
118
+ return failure ();
119
+ }
120
+
121
+ ExtensionOp (const ExtensionOp &) = default ;
122
+ ExtensionOp &operator =(const ExtensionOp &) = default ;
123
+
124
+ // / Creates a new extension op of the same kind.
125
+ Operation *recreate (PatternRewriter &rewriter, Location loc, Type newType,
126
+ Value in) {
127
+ if (kind == ExtensionKind::Sign)
128
+ return rewriter.create <arith::ExtSIOp>(loc, newType, in);
129
+
130
+ return rewriter.create <arith::ExtUIOp>(loc, newType, in);
131
+ }
132
+
133
+ // / Replaces `toReplace` with a new extension op of the same kind.
134
+ void recreateAndReplace (PatternRewriter &rewriter, Operation *toReplace,
135
+ Value in) {
136
+ assert (toReplace->getNumResults () == 1 );
137
+ Type newType = toReplace->getResult (0 ).getType ();
138
+ Operation *newOp = recreate (rewriter, toReplace->getLoc (), newType, in);
139
+ rewriter.replaceOp (toReplace, newOp->getResult (0 ));
140
+ }
141
+
142
+ ExtensionKind getKind () { return kind; }
143
+
144
+ Value getResult () { return op->getResult (0 ); }
145
+ Value getIn () { return op->getOperand (0 ); }
146
+
147
+ Type getType () { return getResult ().getType (); }
148
+ Type getElementType () { return getElementTypeOrSelf (getType ()); }
149
+ Type getInType () { return getIn ().getType (); }
150
+ Type getInElementType () { return getElementTypeOrSelf (getInType ()); }
151
+
152
+ private:
153
+ ExtensionOp (Operation *op, ExtensionKind kind) : op(op), kind(kind) {
154
+ assert (op);
155
+ assert ((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && " Not an extension op" );
156
+ }
157
+ Operation *op = nullptr ;
158
+ ExtensionKind kind = {};
159
+ };
108
160
109
161
// / Returns the integer bitwidth required to represent `value`.
110
162
unsigned calculateBitsRequired (const APInt &value,
@@ -202,19 +254,15 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
202
254
203
255
LogicalResult matchAndRewrite (vector::ExtractOp op,
204
256
PatternRewriter &rewriter) const override {
205
- Operation *def = op.getVector ().getDefiningOp ();
206
- if (!def)
257
+ FailureOr<ExtensionOp> ext =
258
+ ExtensionOp::from (op.getVector ().getDefiningOp ());
259
+ if (failed (ext))
207
260
return failure ();
208
261
209
- return TypeSwitch<Operation *, LogicalResult>(def)
210
- .Case <arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
211
- Value newExtract = rewriter.create <vector::ExtractOp>(
212
- op.getLoc (), extOp.getIn (), op.getPosition ());
213
- rewriter.replaceOpWithNewOp <decltype (extOp)>(op, op.getType (),
214
- newExtract);
215
- return success ();
216
- })
217
- .Default (failure ());
262
+ Value newExtract = rewriter.create <vector::ExtractOp>(
263
+ op.getLoc (), ext->getIn (), op.getPosition ());
264
+ ext->recreateAndReplace (rewriter, op, newExtract);
265
+ return success ();
218
266
}
219
267
};
220
268
@@ -224,19 +272,15 @@ struct ExtensionOverExtractElement final
224
272
225
273
LogicalResult matchAndRewrite (vector::ExtractElementOp op,
226
274
PatternRewriter &rewriter) const override {
227
- Operation *def = op.getVector ().getDefiningOp ();
228
- if (!def)
275
+ FailureOr<ExtensionOp> ext =
276
+ ExtensionOp::from (op.getVector ().getDefiningOp ());
277
+ if (failed (ext))
229
278
return failure ();
230
279
231
- return TypeSwitch<Operation *, LogicalResult>(def)
232
- .Case <arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
233
- Value newExtract = rewriter.create <vector::ExtractElementOp>(
234
- op.getLoc (), extOp.getIn (), op.getPosition ());
235
- rewriter.replaceOpWithNewOp <decltype (extOp)>(op, op.getType (),
236
- newExtract);
237
- return success ();
238
- })
239
- .Default (failure ());
280
+ Value newExtract = rewriter.create <vector::ExtractElementOp>(
281
+ op.getLoc (), ext->getIn (), op.getPosition ());
282
+ ext->recreateAndReplace (rewriter, op, newExtract);
283
+ return success ();
240
284
}
241
285
};
242
286
@@ -246,24 +290,19 @@ struct ExtensionOverExtractStridedSlice final
246
290
247
291
LogicalResult matchAndRewrite (vector::ExtractStridedSliceOp op,
248
292
PatternRewriter &rewriter) const override {
249
- Operation *def = op.getVector ().getDefiningOp ();
250
- if (!def)
293
+ FailureOr<ExtensionOp> ext =
294
+ ExtensionOp::from (op.getVector ().getDefiningOp ());
295
+ if (failed (ext))
251
296
return failure ();
252
297
253
- return TypeSwitch<Operation *, LogicalResult>(def)
254
- .Case <arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
255
- VectorType origTy = op.getType ();
256
- Type inElemTy =
257
- cast<VectorType>(extOp.getIn ().getType ()).getElementType ();
258
- VectorType extractTy = origTy.cloneWith (origTy.getShape (), inElemTy);
259
- Value newExtract = rewriter.create <vector::ExtractStridedSliceOp>(
260
- op.getLoc (), extractTy, extOp.getIn (), op.getOffsets (),
261
- op.getSizes (), op.getStrides ());
262
- rewriter.replaceOpWithNewOp <decltype (extOp)>(op, op.getType (),
263
- newExtract);
264
- return success ();
265
- })
266
- .Default (failure ());
298
+ VectorType origTy = op.getType ();
299
+ VectorType extractTy =
300
+ origTy.cloneWith (origTy.getShape (), ext->getInElementType ());
301
+ Value newExtract = rewriter.create <vector::ExtractStridedSliceOp>(
302
+ op.getLoc (), extractTy, ext->getIn (), op.getOffsets (), op.getSizes (),
303
+ op.getStrides ());
304
+ ext->recreateAndReplace (rewriter, op, newExtract);
305
+ return success ();
267
306
}
268
307
};
269
308
@@ -272,30 +311,22 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
272
311
273
312
LogicalResult matchAndRewrite (vector::InsertOp op,
274
313
PatternRewriter &rewriter) const override {
275
- Operation *def = op.getSource ().getDefiningOp ();
276
- if (!def)
314
+ FailureOr<ExtensionOp> ext =
315
+ ExtensionOp::from (op.getSource ().getDefiningOp ());
316
+ if (failed (ext))
277
317
return failure ();
278
318
279
- return TypeSwitch<Operation *, LogicalResult>(def)
280
- .Case <arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
281
- // Rewrite the insertion in terms of narrower operands
282
- // and later extend the result to the original bitwidth.
283
- FailureOr<vector::InsertOp> newInsert =
284
- createNarrowInsert (op, rewriter, extOp);
285
- if (failed (newInsert))
286
- return failure ();
287
- rewriter.replaceOpWithNewOp <decltype (extOp)>(op, op.getType (),
288
- *newInsert);
289
- return success ();
290
- })
291
- .Default (failure ());
319
+ FailureOr<vector::InsertOp> newInsert =
320
+ createNarrowInsert (op, rewriter, *ext);
321
+ if (failed (newInsert))
322
+ return failure ();
323
+ ext->recreateAndReplace (rewriter, op, *newInsert);
324
+ return success ();
292
325
}
293
326
294
327
FailureOr<vector::InsertOp> createNarrowInsert (vector::InsertOp op,
295
328
PatternRewriter &rewriter,
296
- Operation *insValue) const {
297
- assert ((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));
298
-
329
+ ExtensionOp insValue) const {
299
330
// Calculate the operand and result bitwidths. We can only apply narrowing
300
331
// when the inserted source value and destination vector require fewer bits
301
332
// than the result. Because the source and destination may have different
@@ -306,14 +337,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
306
337
if (failed (origBitsRequired))
307
338
return failure ();
308
339
309
- ExtensionKind kind = getExtensionKind (insValue);
310
340
FailureOr<unsigned > destBitsRequired =
311
- calculateBitsRequired (op.getDest (), kind );
341
+ calculateBitsRequired (op.getDest (), insValue. getKind () );
312
342
if (failed (destBitsRequired) || *destBitsRequired >= *origBitsRequired)
313
343
return failure ();
314
344
315
345
FailureOr<unsigned > insertedBitsRequired =
316
- calculateBitsRequired (insValue-> getOperands (). front (), kind );
346
+ calculateBitsRequired (insValue. getIn (), insValue. getKind () );
317
347
if (failed (insertedBitsRequired) ||
318
348
*insertedBitsRequired >= *origBitsRequired)
319
349
return failure ();
@@ -327,13 +357,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
327
357
return failure ();
328
358
329
359
FailureOr<Type> newInsertedValueTy =
330
- getNarrowType (newInsertionBits, insValue-> getResultTypes (). front ());
360
+ getNarrowType (newInsertionBits, insValue. getType ());
331
361
if (failed (newInsertedValueTy))
332
362
return failure ();
333
363
334
364
Location loc = op.getLoc ();
335
365
Value narrowValue = rewriter.createOrFold <arith::TruncIOp>(
336
- loc, *newInsertedValueTy, insValue-> getResult (0 ));
366
+ loc, *newInsertedValueTy, insValue. getResult ());
337
367
Value narrowDest =
338
368
rewriter.createOrFold <arith::TruncIOp>(loc, *newVecTy, op.getDest ());
339
369
return rewriter.create <vector::InsertOp>(loc, narrowValue, narrowDest,
0 commit comments