@@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
112
112
namespace {
113
113
114
114
// ===----------------------------------------------------------------------===//
115
- // ConvertMemRefAlloc
115
+ // ConvertMemRefAllocation
116
116
// ===----------------------------------------------------------------------===//
117
117
118
- struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
119
- using OpConversionPattern::OpConversionPattern;
118
+ template <typename OpTy>
119
+ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
120
+ using OpConversionPattern<OpTy>::OpConversionPattern;
120
121
121
122
LogicalResult
122
- matchAndRewrite (memref::AllocOp op, OpAdaptor adaptor,
123
+ matchAndRewrite (OpTy op, typename OpTy::Adaptor adaptor,
123
124
ConversionPatternRewriter &rewriter) const override {
124
- auto currentType = op.getMemref ().getType ().cast <MemRefType>();
125
- auto newResultType =
126
- getTypeConverter ()->convertType (op.getType ()).dyn_cast <MemRefType>();
125
+ static_assert (std::is_same<OpTy, memref::AllocOp>() ||
126
+ std::is_same<OpTy, memref::AllocaOp>(),
127
+ " expected only memref::AllocOp or memref::AllocaOp" );
128
+ auto currentType = cast<MemRefType>(op.getMemref ().getType ());
129
+ auto newResultType = dyn_cast<MemRefType>(
130
+ this ->getTypeConverter ()->convertType (op.getType ()));
127
131
if (!newResultType) {
128
132
return rewriter.notifyMatchFailure (
129
133
op->getLoc (),
@@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
132
136
133
137
// Special case zero-rank memrefs.
134
138
if (currentType.getRank () == 0 ) {
135
- rewriter.replaceOpWithNewOp <memref::AllocOp>(
136
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands (),
137
- adaptor.getAlignmentAttr ());
139
+ rewriter.replaceOpWithNewOp <OpTy>(op, newResultType, ValueRange{},
140
+ adaptor.getSymbolOperands (),
141
+ adaptor.getAlignmentAttr ());
138
142
return success ();
139
143
}
140
144
@@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
156
160
rewriter, loc, linearizedMemRefInfo.linearizedSize ));
157
161
}
158
162
159
- rewriter.replaceOpWithNewOp <memref::AllocOp>(
160
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands (),
161
- adaptor.getAlignmentAttr ());
163
+ rewriter.replaceOpWithNewOp <OpTy>(op, newResultType, dynamicLinearizedSize,
164
+ adaptor.getSymbolOperands (),
165
+ adaptor.getAlignmentAttr ());
162
166
return success ();
163
167
}
164
168
};
@@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
344
348
RewritePatternSet &patterns) {
345
349
346
350
// Populate `memref.*` conversion patterns.
347
- patterns
348
- .add <ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
349
- ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
350
- typeConverter, patterns.getContext ());
351
+ patterns.add <ConvertMemRefAllocation<memref::AllocOp>,
352
+ ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
353
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
354
+ ConvertMemRefReinterpretCast>(typeConverter,
355
+ patterns.getContext ());
351
356
memref::populateResolveExtractStridedMetadataPatterns (patterns);
352
357
}
353
358
0 commit comments