Commit 0a8e3dd
[mlir][Interfaces] DestinationStyleOpInterface: Rename hasTensor/BufferSemantics (#77574)
Rename the interface functions as follows:

* `hasTensorSemantics` -> `hasPureTensorSemantics`
* `hasBufferSemantics` -> `hasPureBufferSemantics`

These two functions return "true" if the op has tensor (resp. buffer) operands but no buffer (resp. tensor) operands. Also drop the "ranked" requirement from the interface, i.e., no longer distinguish between ranked and unranked types.

The new function names describe the functions more accurately. They also align with the notion of "tensor semantics" in the bufferization framework. (An op is supposed to be bufferized if it has tensor operands, and we don't care if it also has memref operands.)

This change is in preparation of #75273, which adds `BufferizableOpInterface::hasTensorSemantics`. Renaming the functions in the `DestinationStyleOpInterface` avoids name clashes between the two interfaces.
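
For downstream code the migration is a mechanical rename. The following is a minimal sketch, not part of this commit, showing how the renamed predicates classify an op; the helper function is illustrative. Note the "mixed" case, where an op has both tensor and memref operands and neither predicate holds:

```c++
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;

// Sketch (illustrative helper, not from this commit): classify a DPS op
// using the renamed predicates.
static StringRef classifyDpsSemantics(DestinationStyleOpInterface op) {
  if (op.hasPureTensorSemantics()) // was: hasTensorSemantics()
    return "pure tensor";
  if (op.hasPureBufferSemantics()) // was: hasBufferSemantics()
    return "pure buffer";
  // An op with both tensor and memref operands ("mixed semantics"), or with
  // neither, satisfies neither predicate.
  return "mixed or no tensor/memref operands";
}
```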
Parent: 1aacdfe

25 files changed: +85 −83 lines

mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
Lines changed: 25 additions & 25 deletions

@@ -17,24 +17,24 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
     as initial tensor values for the results of the operation or the init
     buffers to which the results of the op will be written.

-    Init operands must be ranked tensors or ranked memrefs. Input operands can
-    have any type. All non-init operands are DPS inputs.
+    Init operands must be tensors or memrefs. Input operands can have any type.
+    All non-init operands are DPS inputs.

     The init operands of this op are specified by the MutableOperandRange that
     the `getDpsInitsMutable` interface methods returns. This implies that the
     init operands must be a consecutive range of operands.

-    If the op has "tensor semantics", then the input operands are either ranked
-    tensors or other non-tensor/memref types ("scalars"). The init operands are
-    ranked tensors and every tensor init is tied to a corresponding tensor
-    OpResult in a 1-to-1 fashion. The i-th init tensor is tied to the i-th
-    OpResult. The op may not have any additional OpResults. Init operands and
-    their tied OpResults have the same type. Dynamic dimension sizes also match
-    at runtime.
+    Each tensor init operand is tied to a corresponding tensor OpResult in a
+    1-to-1 fashion. The i-th init tensor is tied to the i-th OpResult. The op
+    may not have any additional OpResults. Init operands and their tied
+    OpResults have the same type. Dynamic dimension sizes also match at runtime.

-    If the op has "buffer semantics", then the input operands are either ranked
-    memrefs or other non-tensor/memref types ("scalar" types). Furthermore, the
-    init operands are ranked memrefs and the op has no results.
+    Note: This implies that a destination style op without any tensor inits must
+    not have any OpResults.
+
+    An op has "pure tensor semantics" if it has at least one tensor operand and
+    no buffer (memref) operands. It has "pure buffer semantics" if it has at
+    least one buffer (memref) operand and no tensor operands.

     Destination-passing style abstraction makes certain transformations easier.
     For example, tiling implementation can extract/insert slices from/into the

@@ -148,7 +148,8 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
     /// neither a MemRef nor a tensor value.
     bool isScalar(::mlir::OpOperand *opOperand) {
       assert(opOperand->getOwner() == $_op && "invalid operand");
-      return !::llvm::isa<MemRefType, TensorType>(opOperand->get().getType());
+      return !::llvm::isa<BaseMemRefType, TensorType>(
+          opOperand->get().getType());
     }

     /// Return the OpResult that is tied to the given OpOperand.

@@ -169,37 +170,36 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
       return $_op.getDpsInitOperand(opResult.getResultNumber());
     }

-    /// Return whether the op has buffer semantics. That is the case if the op
-    /// has no ranked tensor operands and at least one memref operand.
-    bool hasBufferSemantics() {
+    /// Return whether the op has pure buffer semantics. That is the case if the
+    /// op has no tensor operands and at least one memref operand.
+    bool hasPureBufferSemantics() {
       // No tensors.
       auto isTensor = [](Value v){
-        return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
+        return ::llvm::isa<::mlir::TensorType>(v.getType());
       };
       if (::llvm::any_of($_op->getOperands(), isTensor))
         return false;
       // At least one memref.
       auto isMemref = [](Value v){
-        return ::llvm::isa<::mlir::MemRefType>(v.getType());
+        return ::llvm::isa<::mlir::BaseMemRefType>(v.getType());
       };
       return llvm::any_of($_op->getOperands(), isMemref);
     }

-    /// Return whether the op has tensor semantics. That is the case if the op
-    /// has no memref operands and at least one ranked tensor operand.
-    bool hasTensorSemantics() {
+    /// Return whether the op has pure tensor semantics. That is the case if the
+    /// op has no memref operands and at least one tensor operand.
+    bool hasPureTensorSemantics() {
       // No memrefs.
       auto isMemref = [](Value v){
-        return ::llvm::isa<::mlir::MemRefType>(v.getType());
+        return ::llvm::isa<::mlir::BaseMemRefType>(v.getType());
       };
       if (::llvm::any_of($_op->getOperands(), isMemref))
         return false;
       // At least one tensor.
       auto isTensor = [](Value v){
-        return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
+        return ::llvm::isa<::mlir::TensorType>(v.getType());
       };
-      return llvm::any_of($_op->getOperands(), isTensor);
-    }
+      return llvm::any_of($_op->getOperands(), isTensor); }
   }];

   let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
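
Beyond the rename, the checks above widened from `RankedTensorType`/`MemRefType` to `TensorType`/`BaseMemRefType`, so unranked tensors and memrefs now count toward tensor/buffer semantics. A minimal sketch of that distinction follows; the type hierarchy facts are real MLIR builtins, but the helper itself is illustrative and not part of this commit:

```c++
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Sketch: how the old vs. new operand checks classify a type.
// RankedTensorType and UnrankedTensorType both derive from TensorType;
// MemRefType (ranked) and UnrankedMemRefType both derive from BaseMemRefType.
static void classify(Type t) {
  bool countedBefore = isa<RankedTensorType>(t) || isa<MemRefType>(t);
  bool countedNow = isa<TensorType>(t) || isa<BaseMemRefType>(t);
  // For tensor<*xf32> or memref<*xf32> (unranked):
  //   countedBefore == false, countedNow == true.
  (void)countedBefore;
  (void)countedNow;
}
```

For example, an op whose only shaped operand has type `memref<*xf32>` now reports pure buffer semantics, whereas the old `hasBufferSemantics` ignored unranked memrefs and returned false.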

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Lines changed: 5 additions & 5 deletions

@@ -550,7 +550,7 @@ struct EraseSelfCopy : OpRewritePattern<CopyOp> {
                                 PatternRewriter &rewriter) const override {
     if (copyOp.getInputs() != copyOp.getOutputs())
       return rewriter.notifyMatchFailure(copyOp, "not a self copy");
-    if (copyOp.hasBufferSemantics())
+    if (copyOp.hasPureBufferSemantics())
       rewriter.eraseOp(copyOp);
     else
       rewriter.replaceOp(copyOp, copyOp.getInputs());

@@ -1112,7 +1112,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
       return failure();

     // In the buffer case, we need to check exact buffer equality.
-    if (genericOp.hasBufferSemantics()) {
+    if (genericOp.hasPureBufferSemantics()) {
       if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
           genericOp.getDpsInputOperand(0)->get() ==
               genericOp.getDpsInitOperand(0)->get()) {

@@ -1123,7 +1123,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     }

     // Mixed semantics is not supported yet.
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();

     // Get the argument number of the returned values. That is the operand

@@ -2257,7 +2257,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {

   LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOp.hasTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
       return failure();

     // Maps must be projected permutations.

@@ -2376,7 +2376,7 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
       getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));

   SmallVector<Type, 4> resultTypes;
-  if (hasTensorSemantics())
+  if (hasPureTensorSemantics())
     resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ struct BubbleUpExtractSliceOpPattern
                                          "expected single output of linalg op");
     }

-    if (!linalgOp.hasTensorSemantics()) {
+    if (!linalgOp.hasPureTensorSemantics()) {
       return rewriter.notifyMatchFailure(sliceOp,
                                          "expected tensor of linalg op");
     }

mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
Lines changed: 3 additions & 3 deletions

@@ -32,13 +32,13 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
   rewriter.setInsertionPoint(op);

   // Nothing to do. This op is already bufferized.
-  if (op.hasBufferSemantics())
+  if (op.hasPureBufferSemantics())
     return success();

   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
   // basis.
-  if (!op.hasTensorSemantics())
-    return op->emitError() << "op does not have tensor semantics";
+  if (!op.hasPureTensorSemantics())
+    return op->emitError() << "op does not have pure tensor semantics";

   // New input operands for the cloned op.
   SmallVector<Value> newInputBuffers;

mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Mixed and buffer sematics aren't supported.
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();

     // Only support ops generating one output for now.

mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
Lines changed: 1 addition & 1 deletion

@@ -258,7 +258,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
   // TODO: this could be generalized to handle `linalg.generic` with buffer
   // operands too but requires allocation for intermediates. Punt on this for
   // now.
-  if (!genericOp.hasTensorSemantics()) {
+  if (!genericOp.hasPureTensorSemantics()) {
     return rewriter.notifyMatchFailure(
         genericOp, "only operations with tensor semantics are handled");
   }

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
     if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
       return failure();

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Lines changed: 8 additions & 8 deletions

@@ -105,7 +105,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   // Consumer can have mixed semantics, just check operand itself has tensor
   // type. Producer must have full tensor semantics to avoid potential
   // aliasing between producer and consumer memrefs.
-  if (!producer.hasTensorSemantics() ||
+  if (!producer.hasPureTensorSemantics() ||
       !isa<RankedTensorType>(fusedOperand->get().getType()))
     return false;

@@ -530,7 +530,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   //   permutations.
   // - The fused tensor is not a scalar.
   // - All the loops are parallel loops.
-  return genericOp.hasTensorSemantics() &&
+  return genericOp.hasPureTensorSemantics() &&
          llvm::all_of(genericOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
                         return cast<AffineMapAttr>(attr)

@@ -1124,7 +1124,7 @@ static SmallVector<ReassociationIndices>
 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                  ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
-  if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
+  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
     return {};

   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {

@@ -1476,7 +1476,7 @@ Operation *createCollapsedOp(LinalgType op,
     outputOperands.push_back(newOutput);
     // If the op has "buffer semantics", then the init operands are ranked
     // memrefs and the op has no results.
-    if (!op.hasBufferSemantics())
+    if (!op.hasPureBufferSemantics())
       resultTypes.push_back(newOutput.getType());
   }

@@ -1521,8 +1521,8 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
       }))
     return failure();

-  bool hasBufferSemantics = op.hasBufferSemantics();
-  if (hasBufferSemantics &&
+  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
+  if (hasPureBufferSemantics &&
       !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
         MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
         if (!memRefToCollapse)

@@ -1705,7 +1705,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {

   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
       Operation *def = opOperand->get().getDefiningOp();

@@ -1857,7 +1857,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {

   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
     bool fillFound = false;
     Block &payload = genericOp.getRegion().front();

mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
Lines changed: 2 additions & 2 deletions

@@ -183,7 +183,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
         dedupedOutpts;
     // If the op doesn't have tensor semantics or outputs should not be removed,
     // keep all the outputs as preserved.
-    if (!genericOp.hasTensorSemantics() || !removeOutputs) {
+    if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
      for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
        origToNewPos[en.index()] = newOutputOperands.size();
        newOutputOperands.push_back(en.value().get());

@@ -317,7 +317,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
                                 PatternRewriter &rewriter) const override {

     // If the op doesnt have tensor semantics, preserve the outputs as is.
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();

     bool hasRemovedCycles = false;

mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   ValueRange outputs = linalgOp.getDpsInits();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
   SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
+  SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};

mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();

     SmallVector<size_t> scalarOperands;

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Lines changed: 3 additions & 3 deletions

@@ -128,7 +128,7 @@ template <typename LoadOpTy, typename StoreOpTy>
 static void emitScalarImplementation(OpBuilder &b, Location loc,
                                      ArrayRef<Value> allIvs,
                                      LinalgOp linalgOp) {
-  assert(linalgOp.hasBufferSemantics() &&
+  assert(linalgOp.hasPureBufferSemantics() &&
          "expected linalg op with buffer semantics");
   SmallVector<Value> indexedValues;
   indexedValues.reserve(linalgOp->getNumOperands());

@@ -218,7 +218,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,

   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  assert(linalgOp.hasBufferSemantics() &&
+  assert(linalgOp.hasPureBufferSemantics() &&
          "expected linalg op with buffer semantics");

   auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());

@@ -264,7 +264,7 @@ class LinalgRewritePattern : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = dyn_cast<LinalgOp>(op);
-    if (!isa<LinalgOp>(op) || !linalgOp.hasBufferSemantics()) {
+    if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
       return rewriter.notifyMatchFailure(
           op, "expected linalg op with buffer semantics");
     }

mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
   Location loc = operation->getLoc();
   auto linalgOp = dyn_cast<LinalgOp>(operation);
   // Exit out on the memref version of this operation.
-  if (!linalgOp || !linalgOp.hasTensorSemantics())
+  if (!linalgOp || !linalgOp.hasPureTensorSemantics())
     return failure();

   auto result = operation->getResult(0);

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
Lines changed: 2 additions & 2 deletions

@@ -168,7 +168,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
   }

   // TODO: there are cases where we may still want to pad to larger sizes.
-  if (!opToPad.hasTensorSemantics())
+  if (!opToPad.hasPureTensorSemantics())
     return rewriter.notifyMatchFailure(opToPad,
                                        "expected operation on tensors");

@@ -265,7 +265,7 @@ mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
   assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
          "invalid options");

-  if (!linalgOp.hasTensorSemantics())
+  if (!linalgOp.hasPureTensorSemantics())
     return rewriter.notifyMatchFailure(
         linalgOp, "only applies to Linalg ops with tensor semantics");

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
Lines changed: 5 additions & 3 deletions

@@ -164,7 +164,8 @@ struct LinalgOpInstancePromotionOptions {
 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     LinalgOp linalgOp, const LinalgPromotionOptions &options)
     : subViews(), alignment(options.alignment) {
-  assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
+  assert(linalgOp.hasPureBufferSemantics() &&
+         "revisit usage of shaped operand");
   auto vUseFullTileBuffers =
       options.useFullTileBuffers.value_or(llvm::SmallBitVector());
   vUseFullTileBuffers.resize(linalgOp->getNumOperands(),

@@ -346,7 +347,8 @@ promoteSubViews(ImplicitLocOpBuilder &b,
 static FailureOr<LinalgOp>
 promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  assert(op.hasPureBufferSemantics() &&
+         "expected linalg op with buffer semantics");

   // 1. Promote the specified views and use them in the new op.
   auto promotedBuffersAndViews = promoteSubViews(b, options, layout);

@@ -400,7 +402,7 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options) {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   // Transformation applies to buffers only.
-  if (!linalgOp || !linalgOp.hasBufferSemantics())
+  if (!linalgOp || !linalgOp.hasPureBufferSemantics())
     return failure();
   // Check that at least one of the requested operands is indeed a subview.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Lines changed: 2 additions & 2 deletions

@@ -212,7 +212,7 @@ struct LinalgOpTilingInterface
                                    Location loc,
                                    ValueRange ivs) const {
     auto linalgOp = cast<LinalgOp>(op);
-    if (!linalgOp.hasBufferSemantics())
+    if (!linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have buffer semantics");

     SmallVector<Value> indexedValues;

@@ -256,7 +256,7 @@ struct LinalgOpPartialReductionInterface
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);

-    if (linalgOp.hasBufferSemantics())
+    if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
     // Insert the new parallel dimension based on the index of the reduction
     // loops. This could be controlled by user for more flexibility.
