Skip to content

Commit 9b05170

Browse files
author
MaheshRavishankar
committed
[mlir] Enhance InferShapedTypeOpInterface and move LinalgOps to use them.
A new `InterfaceMethod` is added to `InferShapedTypeOpInterface` that allows an operation to return the `Value`s for each dim of its results. It is intended for the case where the `Value` returned for each dim is computed using the operands and operation attributes. This interface method is for cases where the result dim of an operation can be computed independently, and it avoids the need to aggregate all dims of a result into a single shape value. This also implies that this is not suitable for cases where the result type is unranked (for which the existing interface methods is to be used). Also added is a canonicalization pattern that uses this interface and resolves the shapes of the output in terms of the shapes of the inputs. Moving Linalg ops to use this interface, so that many canonicalization patterns implemented for individual linalg ops to achieve the same result can be removed in favor of the added canonicalization pattern. Differential Revision: https://reviews.llvm.org/D97887
1 parent 742f663 commit 9b05170

File tree

15 files changed

+395
-209
lines changed

15 files changed

+395
-209
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,18 +1087,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
10871087
>,
10881088
InterfaceMethod<
10891089
/*desc=*/[{
1090-
Return the position in the results of the affine map computed
1091-
by getLoopsToShapesMap() that represents the shape of the
1092-
result value at a dimension.
1090+
Return the range of position in the result of the affine map
1091+
computed by getLoopsToShapesMap() which correspond to the
1092+
AffineExprs used to access the outputs of the operation.
10931093
}],
1094-
/*retTy=*/"Optional<unsigned>",
1095-
/*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
1096-
/*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
1094+
/*retTy=*/"std::pair<unsigned, unsigned>",
1095+
/*methodName=*/"getResultsPositionInLoopsToShapeMap",
1096+
/*args=*/(ins),
10971097
/*methodBody=*/"",
10981098
/*defaultImplementation=*/[{
1099-
if (resultIdx >= getNumOutputs()) return {};
1100-
return getOperandDimPositionInLoopsToShapeMap(
1101-
getNumInputs() + resultIdx, dim);
1099+
return
1100+
{*getOperandDimPositionInLoopsToShapeMap(getNumInputs(), 0),
1101+
(*getOperandDimPositionInLoopsToShapeMap
1102+
(getNumInputs() + getNumOutputs() - 1,
1103+
getOutputShapedType(getNumOutputs()-1).getRank() - 1)) + 1};
11021104
}]
11031105
>,
11041106
InterfaceMethod<
@@ -1226,8 +1228,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
12261228

12271229
/// Returns the value that expresses the shape of the output in terms of
12281230
/// shape of the input operands where possible
1229-
Optional<Value> inferResultDimFromInputShapes
1230-
(OpBuilder &b, Location loc, unsigned resultIdx, unsigned im);
1231+
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
1232+
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
12311233

12321234
//========================================================================//
12331235
// Helper functions to mutate the `operand_segment_sizes` attribute.

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/TypeUtilities.h"
2323
#include "mlir/IR/Types.h"
2424
#include "mlir/Interfaces/CopyOpInterface.h"
25+
#include "mlir/Interfaces/InferTypeOpInterface.h"
2526
#include "mlir/Interfaces/SideEffectInterfaces.h"
2627
#include "mlir/Interfaces/ViewLikeInterface.h"
2728
#include "mlir/Support/LLVM.h"
@@ -107,13 +108,6 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
107108
void getDimsOfType(Operation *op, StringRef iteratorTypeName,
108109
SmallVectorImpl<AffineExpr> &res);
109110

110-
/// For reshape operation, compute the shape of the output based on the result
111-
/// type and shape of the input.
112-
SmallVector<Value, 4>
113-
getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
114-
ArrayRef<int64_t> dstStaticShape,
115-
ArrayRef<AffineMap> reassociation);
116-
117111
namespace detail {
118112
LogicalResult verifyStructuredOpInterface(Operation *op);
119113
} // namespace detail

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1717
include "mlir/Interfaces/ControlFlowInterfaces.td"
18+
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/LoopLikeInterface.td"
1920
include "mlir/Interfaces/SideEffectInterfaces.td"
2021
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -33,7 +34,10 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
3334
let parser = [{ return ::parse$cppClass(parser, result); }];
3435
}
3536

36-
def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
37+
def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
38+
[NoSideEffect,
39+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
40+
["reifyReturnTypeShapesPerResultDim"]>]> {
3741
let summary = "operation to define a tensor of particular value";
3842

3943
let description = [{
@@ -126,7 +130,10 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
126130
}
127131

128132
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
129-
[AttrSizedOperandSegments, NoSideEffect]> {
133+
[AttrSizedOperandSegments,
134+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
135+
["reifyReturnTypeShapesPerResultDim"]>,
136+
NoSideEffect]> {
130137
let summary = "tensor pad operation";
131138
let description = [{
132139
`linalg.pad_tensor` is an operation that pads the `source` tensor
@@ -348,11 +355,6 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
348355
a.cast<AffineMapAttr>().getValue().getResults());
349356
}));
350357
}
351-
SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
352-
return getReshapeOutputShapeFromInputShape(
353-
b, loc, src(), getResultType().getShape(),
354-
getReassociationMaps());
355-
}
356358
}];
357359
let assemblyFormat = [{
358360
$src $reassociation attr-dict `:` type($src) `into` type(results)
@@ -417,7 +419,10 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
417419
let hasCanonicalizer = 1;
418420
}
419421

420-
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
422+
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
423+
"tensor_reshape",
424+
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
425+
["reifyReturnTypeShapesPerResultDim"]>]>,
421426
Arguments<(ins AnyTensor:$src,
422427
AffineMapArrayAttr:$reassociation)>,
423428
Results<(outs AnyTensor:$result)> {

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1818
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
1919
include "mlir/Interfaces/CopyOpInterface.td"
20+
include "mlir/Interfaces/InferTypeOpInterface.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
2122

2223
// Base Tablegen class for Linalg ops.
@@ -25,14 +26,20 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2526
// depending on the specific Linalg op.
2627
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
2728
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [
28-
LinalgStructuredInterface])> {
29+
LinalgStructuredInterface, InferShapedTypeOpInterface])> {
2930
code structuredOpsBaseDecls = [{
3031
// Return the number of induction variables in the basic block. This should
3132
// always be 0 for index-free linalg ops. For IndexedGeneric, this must be
3233
// equal to numLoops.
3334
unsigned getNumPayloadInductionVariables() {
3435
return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
3536
}
37+
38+
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
39+
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
40+
return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
41+
reifiedReturnShapes);
42+
}
3643
}];
3744
}
3845

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,21 +97,53 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
9797
"::mlir::DictionaryAttr":$attributes,
9898
"::mlir::RegionRange":$regions,
9999
"::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
100-
$inferredReturnShapes)
100+
$inferredReturnShapes),
101+
/*methodBody=*/[{}],
102+
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
101103
>,
102104
InterfaceMethod<
103105
/*desc=*/[{Reify the shape computation for the operation.
104106

105-
Insert operations using the given OpBuilder that computes the result
106-
shape.
107+
Insert operations using the given OpBuilder that computes the
108+
result shape. Only one of this method or
109+
`reifyReturnTypeShapesPerResultDim` needs to be overriden by the
110+
operation.
107111
}],
108112
/*retTy=*/"::mlir::LogicalResult",
109113
/*methodName=*/"reifyReturnTypeShapes",
110114
/*args=*/(ins "::mlir::OpBuilder&":$builder,
111-
"::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes),
115+
"::mlir::SmallVectorImpl<Value> &":$reifiedReturnShapes),
112116
/*methodBody=*/[{}],
113117
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
114118
>,
119+
InterfaceMethod<
120+
/*desc=*/[{Reify the shape computation for the operation.
121+
122+
Insert operations using the given OpBuilder that computes the
123+
result shape. The `reifiedReturnShapes` is expected to be
124+
populated with as many vectors as the number of results of the
125+
op (empty if the shape of a result value cannot be computed). If
126+
the returned shape for a result is not empty, its size must
127+
match the rank of the shaped type returned. Consequently, this
128+
interface can only be overridden if the return types are ranked.
129+
130+
If both this method and `reifyReturnTypeShapes` are overridden
131+
by the operation, `reifyReturnTypeShapes` takes precedence. This
132+
method is intended to be used when the shape of each result, dim
133+
pair can be computed independently. Using this method avoids
134+
adding additional instructions to aggregate individual dimension
135+
of a result shape into an single `Value` (and consequently
136+
avoids the need to extract the value from the shape on the
137+
client side).
138+
}],
139+
/*retTy=*/"::mlir::LogicalResult",
140+
/*methodName=*/"reifyReturnTypeShapesPerResultDim",
141+
/*args=*/(ins "::mlir::OpBuilder&":$builder,
142+
"::mlir::SmallVectorImpl<SmallVector<::mlir::Value>>&"
143+
:$reifiedReturnShapes),
144+
/*methodBody=*/[{}],
145+
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
146+
>
115147
];
116148
}
117149

@@ -129,6 +161,7 @@ class InferTensorType<list<string> overridenMethods = []> {
129161
NativeOpTrait<"InferTensorType">
130162
];
131163
}
132-
defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
164+
defvar InferTensorTypeWithReify = InferTensorType<[
165+
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
133166

134167
#endif // MLIR_INFERTYPEOPINTERFACE

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalg
1414
LINK_LIBS PUBLIC
1515
MLIRAffine
1616
MLIRDialectUtils
17+
MLIRInferTypeOpInterface
1718
MLIRIR
1819
MLIRParser
1920
MLIRSideEffectInterfaces

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
188188
for (Value v : getShapedOperands()) {
189189
ShapedType t = v.getType().template cast<ShapedType>();
190190
for (unsigned i = 0, e = t.getRank(); i < e; ++i)
191-
res.push_back(b.create<memref::DimOp>(loc, v, i));
191+
res.push_back(b.createOrFold<memref::DimOp>(loc, v, i));
192192
}
193193
return res;
194194
}
@@ -234,57 +234,58 @@ struct HasAffineDimExprVisitor
234234
llvm::SmallSet<unsigned, 4> positions;
235235
};
236236

237-
Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
238-
Location loc,
239-
unsigned resultIdx,
240-
unsigned dim) {
237+
LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
238+
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
241239
// An example that helps understand the logic below.
242240
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
243241
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
244242
// This is achieved as follows.
245243
// loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
246-
// subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
244+
// subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
247245
// shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
248-
// resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
249-
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
246+
// resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
247+
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
250248
AffineMap loopsToShapesMap = getLoopsToShapesMap();
251249

252250
// Find the position in the above map that represents the shape of the
253251
// result:dim being inferred.
254-
Optional<unsigned> resultDimSubMapPos =
255-
getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
256-
if (!resultDimSubMapPos)
257-
return {};
252+
auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();
258253

259254
/// From loopsToShapesMap extract the submap that represents the shape of the
260-
/// (resultIdx, dim) needed
261-
AffineMap loopToResultDimShapeMap =
262-
loopsToShapesMap.getSubMap(*resultDimSubMapPos);
263-
AffineMap operandShapesToResultDimMap =
264-
loopToResultDimShapeMap.compose(getShapesToLoopsMap());
255+
/// (resultIdx, dim) needed.
256+
SmallVector<unsigned, 4> resultPosRange =
257+
llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
258+
resultShapesSubMapPos.second));
259+
AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
260+
AffineMap resultShapesFromInputShapesMap =
261+
loopToResultsShapeMap.compose(getShapesToLoopsMap());
265262

266263
// Check that the result dim map does not contain the positions corresponding
267264
// to the outputs.
268265
llvm::SmallSet<unsigned, 4> outputDims;
269-
unsigned outputDimPosStart =
270-
getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
271-
unsigned outputDimPosEnd =
272-
getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
273-
getOutputOpOperands()
274-
.back()
275-
.get()
276-
.getType()
277-
.cast<ShapedType>()
278-
.getRank() -
279-
1)
280-
.getValue();
281-
llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
266+
llvm::for_each(resultPosRange,
282267
[&outputDims](unsigned dim) { outputDims.insert(dim); });
283268
HasAffineDimExprVisitor checkDimExpr(outputDims);
284-
if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
285-
return llvm::None;
286-
return applyMapToValues(b, loc, operandShapesToResultDimMap,
287-
createFlatListOfOperandDims(b, loc))[0];
269+
Location loc = getOperation()->getLoc();
270+
auto allResultDimValues =
271+
applyMapToValues(b, loc, resultShapesFromInputShapesMap,
272+
createFlatListOfOperandDims(b, loc));
273+
unsigned pos = 0;
274+
ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
275+
for (auto resultIdx : llvm::seq<unsigned>(0, getNumOutputs())) {
276+
ShapedType resultType = getOutputShapedType(resultIdx);
277+
SmallVector<Value> shapes;
278+
for (unsigned dim : llvm::seq<unsigned>(0, resultType.getRank())) {
279+
if (checkDimExpr.visit(shapeExprs[pos]))
280+
shapes.push_back(
281+
b.createOrFold<memref::DimOp>(loc, getOutput(resultIdx), dim));
282+
else
283+
shapes.push_back(allResultDimValues[pos]);
284+
pos++;
285+
}
286+
reifiedReturnShapes.emplace_back(std::move(shapes));
287+
}
288+
return success();
288289
}
289290

290291
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {

0 commit comments

Comments
 (0)