@@ -71,18 +71,17 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
71
71
return success ();
72
72
}
73
73
74
- // / Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp` .
74
+ // / Specialization for `linalg::GenericOp`.
75
75
// / A pattern to convert Generic Linalg operations which work on tensors to
76
76
// / use buffers. BufferPlacement pass should be later used to move
77
77
// / Alloc operations to the correct positions and insert the missing Dealloc
78
78
// / operations in the correct places.
79
- template <typename GenericOpTy>
80
79
static void
81
80
finalizeBufferAllocationForGenericOp (ConversionPatternRewriter &rewriter,
82
- GenericOpTy genericOp, ValueRange inputs,
81
+ GenericOp genericOp, ValueRange inputs,
83
82
ValueRange outputs) {
84
83
// Generate a new linalg operation that works on buffers.
85
- auto newGenericOp = rewriter.create <GenericOpTy >(
84
+ auto newGenericOp = rewriter.create <GenericOp >(
86
85
genericOp.getLoc (),
87
86
/* resultTensorTypes=*/ llvm::None,
88
87
/* inputs=*/ inputs,
@@ -116,7 +115,6 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
116
115
linalg::LinalgOp linalgOp,
117
116
ValueRange inputs, ValueRange outputs) {
118
117
assert (!isa<linalg::GenericOp>(linalgOp.getOperation ()));
119
- assert (!isa<linalg::IndexedGenericOp>(linalgOp.getOperation ()));
120
118
SmallVector<Value, 8 > newOperands = inputs;
121
119
newOperands.append (outputs.begin (), outputs.end ());
122
120
auto otherOperands = linalgOp.getAssumedNonShapedOperands ();
@@ -195,6 +193,10 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
195
193
LogicalResult
196
194
matchAndRewrite (LinalgOp op, ArrayRef<Value> operands,
197
195
ConversionPatternRewriter &rewriter) const final {
196
+ // Canonicalize indexed generic operations before bufferization.
197
+ if (isa<IndexedGenericOp>(op))
198
+ return failure ();
199
+
198
200
// GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
199
201
if (!op->hasAttr (" operand_segment_sizes" ))
200
202
return failure ();
@@ -215,15 +217,8 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
215
217
216
218
// Delegate to the linalg generic pattern.
217
219
if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
218
- finalizeBufferAllocationForGenericOp<GenericOp>(
219
- rewriter, genericOp, adaptor.inputs (), newOutputBuffers);
220
- return success ();
221
- }
222
-
223
- // Delegate to the linalg indexed generic pattern.
224
- if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(*op)) {
225
- finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
226
- rewriter, genericOp, adaptor.inputs (), newOutputBuffers);
220
+ finalizeBufferAllocationForGenericOp (rewriter, genericOp,
221
+ adaptor.inputs (), newOutputBuffers);
227
222
return success ();
228
223
}
229
224
0 commit comments