Skip to content

Commit 49df068

Browse files
committed
[mlir][arith][NFC] Simplify narrowing patterns with a wrapper type
Add a new wraper type that represents either of `ExtSIOp` or `ExtUIOp`. This is to simplify the code by using a single type, so that we do not have to use templates or branching to handle both extension kinds. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D149485
1 parent f762798 commit 49df068

File tree

1 file changed

+97
-67
lines changed

1 file changed

+97
-67
lines changed

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 97 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
#include "mlir/IR/BuiltinTypes.h"
1616
#include "mlir/IR/MLIRContext.h"
1717
#include "mlir/IR/Matchers.h"
18+
#include "mlir/IR/Operation.h"
1819
#include "mlir/IR/PatternMatch.h"
1920
#include "mlir/IR/TypeUtilities.h"
2021
#include "mlir/Support/LogicalResult.h"
2122
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2223
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/ADT/SmallVector.h"
24-
#include "llvm/ADT/TypeSwitch.h"
2525
#include <cassert>
2626
#include <cstdint>
2727

@@ -100,11 +100,63 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {
100100

101101
enum class ExtensionKind { Sign, Zero };
102102

103-
ExtensionKind getExtensionKind(Operation *op) {
104-
assert(op);
105-
assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
106-
return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
107-
}
103+
/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
104+
/// the exact op type. Exposes helper functions to query the types, operands,
105+
/// and the result. This is so that we can handle both extension kinds without
106+
/// needing to use templates or branching.
107+
class ExtensionOp {
108+
public:
109+
/// Attemps to create a new extension op from `op`. Returns an extension op
110+
/// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
111+
/// otherwise.
112+
static FailureOr<ExtensionOp> from(Operation *op) {
113+
if (auto sext = dyn_cast_or_null<arith::ExtSIOp>(op))
114+
return ExtensionOp{op, ExtensionKind::Sign};
115+
if (auto zext = dyn_cast_or_null<arith::ExtUIOp>(op))
116+
return ExtensionOp{op, ExtensionKind::Zero};
117+
118+
return failure();
119+
}
120+
121+
ExtensionOp(const ExtensionOp &) = default;
122+
ExtensionOp &operator=(const ExtensionOp &) = default;
123+
124+
/// Creates a new extension op of the same kind.
125+
Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
126+
Value in) {
127+
if (kind == ExtensionKind::Sign)
128+
return rewriter.create<arith::ExtSIOp>(loc, newType, in);
129+
130+
return rewriter.create<arith::ExtUIOp>(loc, newType, in);
131+
}
132+
133+
/// Replaces `toReplace` with a new extension op of the same kind.
134+
void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
135+
Value in) {
136+
assert(toReplace->getNumResults() == 1);
137+
Type newType = toReplace->getResult(0).getType();
138+
Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
139+
rewriter.replaceOp(toReplace, newOp->getResult(0));
140+
}
141+
142+
ExtensionKind getKind() { return kind; }
143+
144+
Value getResult() { return op->getResult(0); }
145+
Value getIn() { return op->getOperand(0); }
146+
147+
Type getType() { return getResult().getType(); }
148+
Type getElementType() { return getElementTypeOrSelf(getType()); }
149+
Type getInType() { return getIn().getType(); }
150+
Type getInElementType() { return getElementTypeOrSelf(getInType()); }
151+
152+
private:
153+
ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
154+
assert(op);
155+
assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
156+
}
157+
Operation *op = nullptr;
158+
ExtensionKind kind = {};
159+
};
108160

109161
/// Returns the integer bitwidth required to represent `value`.
110162
unsigned calculateBitsRequired(const APInt &value,
@@ -202,19 +254,15 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
202254

203255
LogicalResult matchAndRewrite(vector::ExtractOp op,
204256
PatternRewriter &rewriter) const override {
205-
Operation *def = op.getVector().getDefiningOp();
206-
if (!def)
257+
FailureOr<ExtensionOp> ext =
258+
ExtensionOp::from(op.getVector().getDefiningOp());
259+
if (failed(ext))
207260
return failure();
208261

209-
return TypeSwitch<Operation *, LogicalResult>(def)
210-
.Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
211-
Value newExtract = rewriter.create<vector::ExtractOp>(
212-
op.getLoc(), extOp.getIn(), op.getPosition());
213-
rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
214-
newExtract);
215-
return success();
216-
})
217-
.Default(failure());
262+
Value newExtract = rewriter.create<vector::ExtractOp>(
263+
op.getLoc(), ext->getIn(), op.getPosition());
264+
ext->recreateAndReplace(rewriter, op, newExtract);
265+
return success();
218266
}
219267
};
220268

@@ -224,19 +272,15 @@ struct ExtensionOverExtractElement final
224272

225273
LogicalResult matchAndRewrite(vector::ExtractElementOp op,
226274
PatternRewriter &rewriter) const override {
227-
Operation *def = op.getVector().getDefiningOp();
228-
if (!def)
275+
FailureOr<ExtensionOp> ext =
276+
ExtensionOp::from(op.getVector().getDefiningOp());
277+
if (failed(ext))
229278
return failure();
230279

231-
return TypeSwitch<Operation *, LogicalResult>(def)
232-
.Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
233-
Value newExtract = rewriter.create<vector::ExtractElementOp>(
234-
op.getLoc(), extOp.getIn(), op.getPosition());
235-
rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
236-
newExtract);
237-
return success();
238-
})
239-
.Default(failure());
280+
Value newExtract = rewriter.create<vector::ExtractElementOp>(
281+
op.getLoc(), ext->getIn(), op.getPosition());
282+
ext->recreateAndReplace(rewriter, op, newExtract);
283+
return success();
240284
}
241285
};
242286

@@ -246,24 +290,19 @@ struct ExtensionOverExtractStridedSlice final
246290

247291
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
248292
PatternRewriter &rewriter) const override {
249-
Operation *def = op.getVector().getDefiningOp();
250-
if (!def)
293+
FailureOr<ExtensionOp> ext =
294+
ExtensionOp::from(op.getVector().getDefiningOp());
295+
if (failed(ext))
251296
return failure();
252297

253-
return TypeSwitch<Operation *, LogicalResult>(def)
254-
.Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
255-
VectorType origTy = op.getType();
256-
Type inElemTy =
257-
cast<VectorType>(extOp.getIn().getType()).getElementType();
258-
VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
259-
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
260-
op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
261-
op.getSizes(), op.getStrides());
262-
rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
263-
newExtract);
264-
return success();
265-
})
266-
.Default(failure());
298+
VectorType origTy = op.getType();
299+
VectorType extractTy =
300+
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
301+
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
302+
op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
303+
op.getStrides());
304+
ext->recreateAndReplace(rewriter, op, newExtract);
305+
return success();
267306
}
268307
};
269308

@@ -272,30 +311,22 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
272311

273312
LogicalResult matchAndRewrite(vector::InsertOp op,
274313
PatternRewriter &rewriter) const override {
275-
Operation *def = op.getSource().getDefiningOp();
276-
if (!def)
314+
FailureOr<ExtensionOp> ext =
315+
ExtensionOp::from(op.getSource().getDefiningOp());
316+
if (failed(ext))
277317
return failure();
278318

279-
return TypeSwitch<Operation *, LogicalResult>(def)
280-
.Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
281-
// Rewrite the insertion in terms of narrower operands
282-
// and later extend the result to the original bitwidth.
283-
FailureOr<vector::InsertOp> newInsert =
284-
createNarrowInsert(op, rewriter, extOp);
285-
if (failed(newInsert))
286-
return failure();
287-
rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
288-
*newInsert);
289-
return success();
290-
})
291-
.Default(failure());
319+
FailureOr<vector::InsertOp> newInsert =
320+
createNarrowInsert(op, rewriter, *ext);
321+
if (failed(newInsert))
322+
return failure();
323+
ext->recreateAndReplace(rewriter, op, *newInsert);
324+
return success();
292325
}
293326

294327
FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
295328
PatternRewriter &rewriter,
296-
Operation *insValue) const {
297-
assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));
298-
329+
ExtensionOp insValue) const {
299330
// Calculate the operand and result bitwidths. We can only apply narrowing
300331
// when the inserted source value and destination vector require fewer bits
301332
// than the result. Because the source and destination may have different
@@ -306,14 +337,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
306337
if (failed(origBitsRequired))
307338
return failure();
308339

309-
ExtensionKind kind = getExtensionKind(insValue);
310340
FailureOr<unsigned> destBitsRequired =
311-
calculateBitsRequired(op.getDest(), kind);
341+
calculateBitsRequired(op.getDest(), insValue.getKind());
312342
if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
313343
return failure();
314344

315345
FailureOr<unsigned> insertedBitsRequired =
316-
calculateBitsRequired(insValue->getOperands().front(), kind);
346+
calculateBitsRequired(insValue.getIn(), insValue.getKind());
317347
if (failed(insertedBitsRequired) ||
318348
*insertedBitsRequired >= *origBitsRequired)
319349
return failure();
@@ -327,13 +357,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
327357
return failure();
328358

329359
FailureOr<Type> newInsertedValueTy =
330-
getNarrowType(newInsertionBits, insValue->getResultTypes().front());
360+
getNarrowType(newInsertionBits, insValue.getType());
331361
if (failed(newInsertedValueTy))
332362
return failure();
333363

334364
Location loc = op.getLoc();
335365
Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
336-
loc, *newInsertedValueTy, insValue->getResult(0));
366+
loc, *newInsertedValueTy, insValue.getResult());
337367
Value narrowDest =
338368
rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
339369
return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,

0 commit comments

Comments
 (0)