Skip to content

Commit dc32665

Browse files
Fix tests
1 parent 6c53f21 commit dc32665

File tree

9 files changed

+167
-115
lines changed

9 files changed

+167
-115
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,10 @@ def Vector_ContractionOp :
214214
// IndexingMapOpInterface interface methods implementation.
215215
//===------------------------------------------------------------------===//
216216
ArrayRef<int64_t> getShape(OpOperand * opOperand) {
217-
assert(opOperand->getOwner() == this->getOperation());
218217
Type t = opOperand->get().getType();
219-
return cast<VectorType>(t).getShape();
218+
if (auto vt = dyn_cast<VectorType>(t))
219+
return vt.getShape();
220+
return {};
220221
}
221222
}];
222223

mlir/include/mlir/Interfaces/IndexingMapOpInterface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
#include "mlir/IR/BuiltinTypes.h"
1515
#include "mlir/IR/OpDefinition.h"
1616

17+
namespace mlir {
18+
namespace detail {
19+
/// Verify that `op` conforms to the invariants of StructuredOpInterface
20+
LogicalResult verifyIndexingMapOpInterface(Operation *op);
21+
} // namespace detail
22+
} // namespace mlir
23+
1724
/// Include the generated interface declarations.
1825
#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"
1926

mlir/include/mlir/Interfaces/IndexingMapOpInterface.td

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,6 @@ def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
8383
return concatAffineMaps(maps, $_op.getContext());
8484
}]
8585
>,
86-
InterfaceMethod<
87-
/*desc=*/[{
88-
Like `getShape`, but only returns statically-known information, without
89-
generating any new IR. For each shape dimension, returns >=0 if that
90-
dimension is statically known, or ShapedType::kDynamic otherwise.
91-
}],
92-
/*retTy=*/"SmallVector<int64_t>",
93-
/*methodName=*/"getStaticShape",
94-
/*args=*/(ins),
95-
/*methodBody=*/"",
96-
/*defaultImplementation=*/[{
97-
SmallVector<int64_t> res;
98-
for (OpOperand &opOperand : this->getOperation()->getOpOperands())
99-
llvm::append_range(res, $_op.getShape(&opOperand));
100-
return res;
101-
}]
102-
>,
10386
InterfaceMethod<
10487
/*desc=*/[{
10588
Hook to provide a custom AffineMap used to construct the
@@ -126,23 +109,46 @@ def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
126109
>,
127110
InterfaceMethod<
128111
/*desc=*/[{
129-
Returns the statically-known loop ranges. Composes
130-
`getShapesToLoopsMap()` with the result of `getStaticShape`.
112+
Returns the static shape of the underlying operand (note this is
113+
op-specific behavior).
114+
Returns ShapedType::kDynamic for non-statically-known loop ranges.
115+
}],
116+
/*retTy=*/"SmallVector<int64_t>",
117+
/*methodName=*/"getStaticOperandShape",
118+
/*args=*/(ins "OpOperand*":$opOperand),
119+
/*methodBody=*/"",
120+
/*defaultImplementation=*/[{
121+
SmallVector<int64_t> res;
122+
llvm::append_range(res, $_op.getShape(opOperand));
123+
return res;
124+
}]
125+
>,
126+
InterfaceMethod<
127+
/*desc=*/[{
128+
Returns loop ranges by composing `getShapesToLoopsMap()` with the
129+
flattened list of operand shapes.
131130
Returns ShapedType::kDynamic for non-statically-known loop ranges.
132-
This is expected to be called by a valid Linalg op
133131
}],
134-
/*retTy=*/"SmallVector<int64_t, 4>",
132+
/*retTy=*/"SmallVector<int64_t>",
135133
/*methodName=*/"getStaticLoopRanges",
136134
/*args=*/(ins),
137135
/*methodBody=*/"",
138136
/*defaultImplementation=*/[{
139-
SmallVector<int64_t> viewSizes = $_op.getStaticShape();
137+
SmallVector<int64_t> allShapesSizes;
138+
for (OpOperand &opOperand : this->getOperation()->getOpOperands())
139+
llvm::append_range(allShapesSizes, $_op.getShape(&opOperand));
140140
AffineMap invertedMap = $_op.getShapesToLoopsMap();
141-
assert(invertedMap && "expected a valid Linalg op to call the method");
142-
return invertedMap.compose(viewSizes);
141+
assert(invertedMap && "expected a valid op");
142+
return invertedMap.compose(allShapesSizes);
143143
}]
144-
>,
144+
>
145145
];
146+
let extraClassDeclaration = [{
147+
// Verifier implementation for IndexingMapOpInterface.
148+
// This must be called manually as part of other verifiers so that the
149+
// verification order, and meaningful error messages, are not preempted.
150+
LogicalResult verifyImpl();
151+
}];
146152
}
147153

148-
#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
154+
#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 3 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,106 +1215,27 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
12151215
if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
12161216
return failure();
12171217

1218-
// All input/output operands must be indexed.
1219-
if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
1220-
linalgOp->getNumOperands())
1221-
return op->emitOpError("expected the number of indexing_map (")
1222-
<< linalgOp.getIndexingMapsArray().size()
1223-
<< ") to be equal to the number of input/output operands ("
1224-
<< linalgOp->getNumOperands() << ")";
1218+
// Delayed calling of IndexingMapOpInterface::verifyImpl.
1219+
if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1220+
return failure();
12251221

12261222
// Set this flag if this op has user defined maps. This is required to guard
12271223
// the below error condition which assume default indexing maps.
12281224
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
12291225
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1230-
1231-
// Symbols disallowed.
1232-
if (indexingMap.getNumSymbols() != 0)
1233-
return op->emitOpError("unexpected symbols in indexing_map #")
1234-
<< opOperand.getOperandNumber();
1235-
12361226
// Domain must be consistent.
12371227
unsigned numLoops = linalgOp.getNumLoops();
12381228
if (indexingMap.getNumDims() != numLoops)
12391229
return op->emitOpError("expected indexing_map #")
12401230
<< opOperand.getOperandNumber() << " to have " << numLoops
12411231
<< " dim(s) to match the number of loops";
1242-
1243-
int64_t rank = linalgOp.getRank(&opOperand);
1244-
1245-
if (indexingMap.getNumResults() != rank)
1246-
return op->emitOpError("expected operand rank (")
1247-
<< rank << ") to match the result rank of indexing_map #"
1248-
<< opOperand.getOperandNumber() << " ("
1249-
<< indexingMap.getNumResults() << ")";
12501232
}
12511233
SmallVector<unsigned> redDims;
12521234
linalgOp.getReductionDims(redDims);
12531235

12541236
if (!linalgOp.getShapesToLoopsMap())
12551237
return op->emitOpError("expected the shape-to-loops map to be non-null");
12561238

1257-
// Check if given shapes match to inferred shapes.
1258-
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
1259-
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
1260-
// Verify only static cases since we can't get exact dimension sizes and
1261-
// loop ranges for dynamic cases in this stage.
1262-
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1263-
for (int64_t &range : endLoopRangeValues)
1264-
range -= 1;
1265-
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1266-
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1267-
SmallVector<int64_t, 4> startIndices =
1268-
indexingMap.compose(startLoopRangeValues);
1269-
SmallVector<int64_t, 4> endIndices =
1270-
indexingMap.compose(endLoopRangeValues);
1271-
ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
1272-
for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
1273-
// Ignore dynamic dimension or the case that the dimension size is 0
1274-
if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1275-
continue;
1276-
1277-
// The first index or last index should be the maximum or the minimum in
1278-
// the inferred index ranges since the range is increasing or
1279-
// decreasing. The size of dimensions of input/output operands and the
1280-
// maximum value + 1 in the inferred range should be the same. But, for
1281-
// now we check if the inferred ranges are in boundary of input/output
1282-
// operands' size or not in case that Affine Expressions are complicated
1283-
// such as d0 * 3
1284-
// + d1 since it is not easy to handle the issues.
1285-
// Found the case that this solution can't check, for example, (d0, d1)
1286-
// -> (d1 - d0)
1287-
int64_t inferredDimSize =
1288-
std::max(startIndices[dim], endIndices[dim]) + 1;
1289-
if (std::min(startIndices[dim], endIndices[dim]) < 0) {
1290-
std::string mapStr;
1291-
{
1292-
llvm::raw_string_ostream os(mapStr);
1293-
os << indexingMap;
1294-
}
1295-
return op->emitOpError(
1296-
"unexpected result less than 0 at expression #")
1297-
<< dim << " in " << mapStr;
1298-
}
1299-
if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
1300-
if (inferredDimSize != shape[dim]) {
1301-
return op->emitOpError("inferred input/output operand #")
1302-
<< opOperand.getOperandNumber() << " has shape's dimension #"
1303-
<< dim << " to be " << inferredDimSize << ", but found "
1304-
<< shape[dim];
1305-
}
1306-
} else {
1307-
if (inferredDimSize > shape[dim]) {
1308-
return op->emitOpError("inferred input/output operand #")
1309-
<< opOperand.getOperandNumber() << " has shape's dimension #"
1310-
<< dim << " to be greater than or equal to "
1311-
<< inferredDimSize << ", but found " << shape[dim];
1312-
}
1313-
}
1314-
}
1315-
}
1316-
}
1317-
13181239
// Check the region has exactly one block.
13191240
if (linalgOp->getNumRegions() != 1 ||
13201241
!llvm::hasSingleElement(linalgOp->getRegion(0)))

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
398398
return rewriter.notifyMatchFailure(genericOp,
399399
"invalid indexing maps for operation");
400400
}
401-
SmallVector<int64_t> dims = genericOp.getStaticShape();
401+
402+
SmallVector<int64_t> allShapesSizes;
403+
for (OpOperand &opOperand : genericOp->getOpOperands())
404+
llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
402405

403406
// 1a. Get the allowed list of dimensions to drop from the `options`.
404407
SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
@@ -411,7 +414,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
411414
llvm::SmallDenseSet<unsigned> unitDims;
412415
for (const auto &expr : enumerate(invertedMap.getResults())) {
413416
if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
414-
if (dims[dimExpr.getPosition()] == 1 &&
417+
if (allShapesSizes[dimExpr.getPosition()] == 1 &&
415418
unitDimsFilter.count(expr.index()))
416419
unitDims.insert(expr.index());
417420
}

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "mlir/IR/BuiltinTypes.h"
3232
#include "mlir/IR/OpDefinition.h"
3333
#include "mlir/IR/PatternMatch.h"
34+
#include "mlir/IR/Value.h"
3435
#include "mlir/Support/LLVM.h"
3536
#include "mlir/Transforms/RegionUtils.h"
3637
#include "llvm/ADT/STLExtras.h"
@@ -2217,7 +2218,9 @@ static LogicalResult vectorizeLinalgOpPrecondition(
22172218
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
22182219
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
22192220
// tensor with dimension of 0 cannot be vectorized.
2220-
if (llvm::is_contained(linalgOp.getStaticShape(), 0))
2221+
if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2222+
return llvm::is_contained(linalgOp.getShape(&operand), 0);
2223+
}))
22212224
return failure();
22222225
// Check API contract for input vector sizes.
22232226
if (!inputVectorSizes.empty() &&

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,8 @@ LogicalResult ContractionOp::verify() {
10631063
if (!isSupportedCombiningKind(getKind(), elementType))
10641064
return emitOpError("unsupported contraction type");
10651065

1066-
return success();
1066+
// Delayed calling of IndexingMapOpInterface::verifyImpl.
1067+
return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
10671068
}
10681069

10691070
// MaskableOpInterface methods.

mlir/lib/Interfaces/IndexingMapOpInterface.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,113 @@ using namespace mlir;
1313
namespace mlir {
1414
#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
1515
} // namespace mlir
16+
17+
LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
18+
// All input/output operands must be indexed.
19+
if (static_cast<int64_t>(getIndexingMapsArray().size()) !=
20+
getOperation()->getNumOperands())
21+
return this->emitOpError("expected the number of indexing_map (")
22+
<< getIndexingMapsArray().size()
23+
<< ") to be equal to the number of input/output operands ("
24+
<< getOperation()->getNumOperands() << ")";
25+
26+
AffineMap invertedMap = getShapesToLoopsMap();
27+
if (!invertedMap) {
28+
std::string str;
29+
llvm::raw_string_ostream os(str);
30+
getLoopsToShapesMap().print(os);
31+
return this->emitOpError("invalid indexing maps are non-invertible: ")
32+
<< "(" << str << ")";
33+
}
34+
35+
SmallVector<int64_t> endLoopRangeValues = getStaticLoopRanges();
36+
37+
// Set this flag if this op has user defined maps. This is required to guard
38+
// the below error condition which assume default indexing maps.
39+
for (OpOperand &opOperand : getOperation()->getOpOperands()) {
40+
AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
41+
42+
// Symbols disallowed.
43+
if (indexingMap.getNumSymbols() != 0)
44+
return getOperation()->emitOpError("unexpected symbols in indexing_map #")
45+
<< opOperand.getOperandNumber();
46+
47+
// Domain must be consistent.
48+
if (indexingMap.getNumDims() != endLoopRangeValues.size())
49+
return getOperation()->emitOpError("expected indexing_map #")
50+
<< opOperand.getOperandNumber() << " to have "
51+
<< endLoopRangeValues.size()
52+
<< " dim(s) to match the number of loops";
53+
54+
SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
55+
int64_t rank = shape.size();
56+
57+
if (indexingMap.getNumResults() != rank)
58+
return getOperation()->emitOpError("expected operand rank (")
59+
<< rank << ") to match the result rank of indexing_map #"
60+
<< opOperand.getOperandNumber() << " ("
61+
<< indexingMap.getNumResults() << ")";
62+
}
63+
64+
// Check if given shapes match to inferred shapes.
65+
SmallVector<int64_t> startLoopRangeValues(endLoopRangeValues.size(), 0);
66+
// Verify only static cases since we can't get exact dimension sizes and
67+
// loop ranges for dynamic cases in this stage.
68+
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
69+
// Exclusive end range.
70+
for (int64_t &range : endLoopRangeValues)
71+
range -= 1;
72+
for (OpOperand &opOperand : getOperation()->getOpOperands()) {
73+
AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
74+
SmallVector<int64_t> startIndices =
75+
indexingMap.compose(startLoopRangeValues);
76+
SmallVector<int64_t> endIndices = indexingMap.compose(endLoopRangeValues);
77+
SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
78+
for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
79+
// Ignore dynamic dimension or the case that the dimension size is 0
80+
if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
81+
continue;
82+
83+
// The first index or last index should be the maximum or the minimum in
84+
// the inferred index ranges since the range is increasing or
85+
// decreasing. The size of dimensions of input/output operands and the
86+
// maximum value + 1 in the inferred range should be the same. But, for
87+
// now we check if the inferred ranges are in boundary of input/output
88+
// operands' size or not in case that Affine Expressions are complicated
89+
// such as d0 * 3
90+
// + d1 since it is not easy to handle the issues.
91+
// Found the case that this solution can't check, for example, (d0, d1)
92+
// -> (d1 - d0)
93+
int64_t inferredDimSize =
94+
std::max(startIndices[dim], endIndices[dim]) + 1;
95+
if (std::min(startIndices[dim], endIndices[dim]) < 0) {
96+
std::string mapStr;
97+
{
98+
llvm::raw_string_ostream os(mapStr);
99+
os << indexingMap;
100+
}
101+
return this->emitOpError(
102+
"unexpected result less than 0 at expression #")
103+
<< dim << " in " << mapStr;
104+
}
105+
if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
106+
if (inferredDimSize != shape[dim]) {
107+
return this->emitOpError("inferred input/output operand #")
108+
<< opOperand.getOperandNumber() << " has shape's dimension #"
109+
<< dim << " to be " << inferredDimSize << ", but found "
110+
<< shape[dim];
111+
}
112+
} else {
113+
if (inferredDimSize > shape[dim]) {
114+
return this->emitOpError("inferred input/output operand #")
115+
<< opOperand.getOperandNumber() << " has shape's dimension #"
116+
<< dim << " to be greater than or equal to "
117+
<< inferredDimSize << ", but found " << shape[dim];
118+
}
119+
}
120+
}
121+
}
122+
}
123+
124+
return success();
125+
}

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off
151151
// -----
152152

153153
func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
154-
// expected-error @+1 {{expected the shape-to-loops map to be non-null}}
154+
// expected-error @+1 {{invalid indexing maps are non-invertible: ((d0, d1) -> (d0 + d1, d0 + d1))}}
155155
linalg.generic {
156156
indexing_maps = [
157157
affine_map<(i, j) -> (i + j)>,

0 commit comments

Comments
 (0)