Commit a2a4bc5

Author: Tobias Gysi (authored and committed)
[mlir][linalg] All StructuredOp parameters are inputs or outputs.
Adapt the StructuredOp verifier to ensure all operands are either in the input or the output group. The change is possible after adding support for scalar input operands (https://reviews.llvm.org/D104220).

Differential Revision: https://reviews.llvm.org/D104783
1 parent d156637 commit a2a4bc5
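At its core, the change tightens the StructuredOp verifier in LinalgInterfaces.cpp: the operand count no longer merely has to be at least the number of inputs plus outputs, it has to match it exactly, and at least one output operand is required. The standalone sketch below restates that rule; it mirrors the hunk in LinalgInterfaces.cpp further down, the helper name checkOperandGrouping is hypothetical, and the include paths assume the upstream tree layout at the time of this commit.

// Sketch of the tightened operand check, not the literal upstream code.
// Helper name is made up; includes assume the 2021 MLIR tree layout.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

static LogicalResult checkOperandGrouping(linalg::LinalgOp linalgOp) {
  int64_t numInputs = linalgOp.getNumInputs();
  int64_t numOutputs = linalgOp.getNumOutputs();
  // A StructuredOp must have at least one output operand.
  if (numOutputs == 0)
    return linalgOp.getOperation()->emitOpError(
        "expected at least one output operand");
  // verifyNOperands (rather than verifyAtLeastNOperands) rejects any operand
  // that falls outside the input and output groups.
  return OpTrait::impl::verifyNOperands(linalgOp.getOperation(),
                                        numInputs + numOutputs);
}

With that rule in place, StructuredOps can no longer carry trailing operands outside the two groups, so the getAssumedNonShapedOperands helper and every site that appended its result when cloning (the canonicalization patterns in LinalgOps.cpp, Bufferize, ComprehensiveBufferize, Fusion, Tiling, and Transforms) are removed, which accounts for most of the deleted lines.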

File tree: 8 files changed (+9, -41 lines)

8 files changed

+9
-41
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 3 additions & 21 deletions
@@ -253,7 +253,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         /*args=*/(ins),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return getNumInputs() + getNumOutputs();
+          return this->getOperation()->getNumOperands();
         }]
       >,
       //===------------------------------------------------------------------===//
@@ -346,8 +346,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           result.reserve(numOutputs);
           llvm::transform(
             this->getOperation()->getOpOperands()
-              .drop_front(getNumInputs())
-              .take_front(numOutputs),
+              .take_back(numOutputs),
             std::back_inserter(result),
             [](OpOperand &opOperand) { return &opOperand; });
           return result;
@@ -458,8 +457,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           OpOperandVector result;
           result.reserve(numInputsAndOutputs);
           llvm::transform(
-            this->getOperation()->getOpOperands()
-              .take_front(numInputsAndOutputs),
+            this->getOperation()->getOpOperands(),
             std::back_inserter(result),
             [](OpOperand &opOperand) { return &opOperand; });
           return result;
@@ -928,22 +926,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     /// `createFlatListOfOperandStaticDims`.
     SmallVector<int64_t, 4> computeStaticLoopSizes();
 
-    /// Returns all the operands past the inputs, output_buffers and
-    /// init_tensors operands. Asserts that these operands are value types to
-    /// allow transformations like tiling to just use the values when cloning
-    /// `linalgOp`.
-    Operation::operand_range getAssumedNonShapedOperands() {
-      Operation::operand_range res{
-          getOperation()->getOperands().begin() + getNumInputsAndOutputs(),
-          getOperation()->getOperands().end()};
-      for (Type t : TypeRange{res}) {
-        (void)t;
-        assert((t.isSignlessIntOrIndexOrFloat() || t.template isa<VectorType>())
-               && "expected scalar or vector type");
-      }
-      return res;
-    }
-
     /// Returns the value that expresses the shape of the output in terms of
     /// shape of the input operands where possible
     LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 6 additions & 5 deletions
@@ -318,14 +318,15 @@ LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-  // Expect at least one input/output operand.
+  // Expect at least one output operand.
   // This means an op that constructs a tensor out of indices cannot be a
   // LinalgOp at the moment. For now this will have to be a special op until we
   // have output shape operands that are not tensors.
-  int64_t numInputsAndOutputs = linalgOp.getNumInputsAndOutputs();
-  if (numInputsAndOutputs == 0)
-    return op->emitOpError("expected at least one input/output operand");
-  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, numInputsAndOutputs)))
+  int64_t numInputs = linalgOp.getNumInputs();
+  int64_t numOutputs = linalgOp.getNumOutputs();
+  if (numOutputs == 0)
+    return op->emitOpError("expected at least one output operand");
+  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
     return failure();
   // Should have at least one output tensor per result tensor.
   // Can also have outbut buffers that do not correspond to results.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 0 additions & 3 deletions
@@ -3038,8 +3038,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
                            : opOperand->get());
       newResultTypes.push_back(newOperands.back().getType());
     }
-    auto extraOperands = op.getAssumedNonShapedOperands();
-    newOperands.append(extraOperands.begin(), extraOperands.end());
     // Clone op.
     Operation *newOp =
         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
@@ -3109,7 +3107,6 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
       newOperands.push_back(opOperand->get());
     SmallVector<Value> outputOperands = op.getOutputOperands();
     llvm::append_range(newOperands, outputOperands);
-    llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
 
     // Repair the indexing maps by filtering out the ones that have been
     // eliminated.

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 0 additions & 2 deletions
@@ -119,8 +119,6 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
   assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
   SmallVector<Value, 8> newOperands = inputs;
   newOperands.append(outputs.begin(), outputs.end());
-  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalgOp.clone(rewriter, linalgOp.getLoc(),
                  /*resultTypes=*/ArrayRef<Type>{}, newOperands);
   // Replace the results of the old op with the new output buffers.

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 0 additions & 2 deletions
@@ -1241,8 +1241,6 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
   newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
-  auto otherOperands = op.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
 
   // Replace the results of the old op with the new output buffers.

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 0 additions & 4 deletions
@@ -205,10 +205,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
                                    getTiledOperands(b, producer), ivs,
                                    tileSizes, sizeBounds));
 
-  // Append the other operands.
-  auto operands = producer.getAssumedNonShapedOperands();
-  clonedShapes.append(operands.begin(), operands.end());
-
   // Iterate over the results in order.
   // Extract the subtensor type from the linearized range.
   // Since we do not enforce any canonicalizations on the fly, this is always

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 0 additions & 2 deletions
@@ -242,8 +242,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
       applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
   SmallVector<Value, 4> tiledOperands = makeTiledShapes(
       b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
-  auto nonShapedOperands = op.getAssumedNonShapedOperands();
-  tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end());
 
   // TODO: use an interface/adaptor to avoid leaking position in
   // `tiledOperands`.

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 0 additions & 2 deletions
@@ -190,8 +190,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
-  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalg::LinalgOp paddedOp =
       opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
 
