Skip to content

Commit 8b525c9

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Add utility function that return static loop bounds of Linalg ops
Differential Revision: https://reviews.llvm.org/D91749
1 parent b2f6630 commit 8b525c9

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,23 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
114114
return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
115115
}
116116

117+
/// Like `getShape`, but only returns statically-known information, without
118+
/// generating any new IR. For each shape dimension, returns >=0 if that
119+
/// dimension is statically known, or -1 otherwise.
120+
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
121+
117122
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
118123
/// concatenated indexing maps to the result of `getShape`. Returns None if
119124
/// the bounds computation fails.
120125
Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
121126
LinalgOp linalgOp);
122127

128+
/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
129+
/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
130+
/// Returns None if inverting the concatenated indexing map fails. Returns -1
131+
/// for non-statically-known loop ranges.
132+
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
133+
123134
/// Returns the values obtained by applying `map` to the list of values.
124135
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
125136
AffineMap map, ValueRange values);

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
156156
return res;
157157
}
158158

159+
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
160+
SmallVector<int64_t, 8> res;
161+
for (Value v : linalgOp.getShapedOperands()) {
162+
auto shape = v.getType().cast<ShapedType>().getShape();
163+
res.append(shape.begin(), shape.end());
164+
}
165+
return res;
166+
}
167+
159168
Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
160169
LinalgOp linalgOp) {
161170
SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
@@ -166,6 +175,15 @@ Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
166175
return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
167176
}
168177

178+
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
179+
SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
180+
AffineMap invertedMap =
181+
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
182+
if (!invertedMap)
183+
return {};
184+
return invertedMap.compose(viewSizes);
185+
}
186+
169187
/// Specialization to build an scf "for" nest.
170188
template <>
171189
void GenerateLoopNest<scf::ForOp>::doit(

0 commit comments

Comments
 (0)