Skip to content

Commit 203fad4

Browse files
[mlir][DialectUtils] Cleanup IndexingUtils and provide more affine variants while reusing implementations
Differential Revision: https://reviews.llvm.org/D145784
1 parent c113d0b commit 203fad4

File tree

11 files changed

+348
-145
lines changed

11 files changed

+348
-145
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 140 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,68 @@
2323
namespace mlir {
2424
class ArrayAttr;
2525

26-
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
27-
int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
28-
29-
/// Given the strides together with a linear index in the dimension
30-
/// space, returns the vector-space offsets in each dimension for a
31-
/// de-linearized index.
32-
SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
33-
int64_t linearIndex);
26+
//===----------------------------------------------------------------------===//
27+
// Utils that operate on static integer values.
28+
//===----------------------------------------------------------------------===//
3429

35-
/// Given a set of sizes, compute and return the strides (i.e. the number of
36-
/// linear incides to skip along the (k-1) most minor dimensions to get the next
37-
/// k-slice). This is also the basis that one can use to linearize an n-D offset
38-
/// confined to `[0 .. sizes]`.
39-
SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);
30+
/// Given a set of sizes, return the suffix product.
31+
///
32+
/// When applied to slicing, this is the calculation needed to derive the
33+
/// strides (i.e. the number of linear indices to skip along the (k-1) most
34+
/// minor dimensions to get the next k-slice).
35+
///
36+
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
37+
///
38+
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
39+
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
40+
///
41+
/// `sizes` elements are asserted to be non-negative.
42+
///
43+
/// Return an empty vector if `sizes` is empty.
44+
SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
45+
inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
46+
return computeSuffixProduct(sizes);
47+
}
4048

41-
/// Return a vector containing llvm::zip of v1 and v2 multiplied elementwise.
49+
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
50+
///
51+
/// Return an empty vector if `v1` and `v2` are empty.
4252
SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
4353
ArrayRef<int64_t> v2);
4454

45-
/// Compute and return the multi-dimensional integral ratio of `subShape` to
46-
/// the trailing dimensions of `shape`. This represents how many times
47-
/// `subShape` fits within `shape`.
48-
/// If integral division is not possible, return std::nullopt.
55+
/// Return the number of elements of basis (i.e. the max linear index).
56+
/// Return `0` if `basis` is empty.
57+
///
58+
/// `basis` elements are asserted to be non-negative.
59+
///
60+
/// Return `0` if `basis` is empty.
61+
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
62+
63+
/// Return the linearized index of 'offsets' w.r.t. 'basis'.
64+
///
65+
/// `basis` elements are asserted to be non-negative.
66+
int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
67+
68+
/// Given the strides together with a linear index in the dimension space,
69+
/// return the vector-space offsets in each dimension for a de-linearized index.
70+
/// `strides` elements are asserted to be non-negative.
71+
///
72+
/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
73+
/// vector of int64_t
74+
/// `[li % s0, (li / s0) % s1, ..., (li / s0 / .. / sn-1) % sn]`
75+
SmallVector<int64_t> delinearize(int64_t linearIndex,
76+
ArrayRef<int64_t> strides);
77+
78+
/// Return the multi-dimensional integral ratio of `subShape` to the trailing
79+
/// dimensions of `shape`. This represents how many times `subShape` fits
80+
/// within `shape`. If integral division is not possible, return std::nullopt.
4981
/// The trailing `subShape.size()` entries of both shapes are assumed (and
50-
/// enforced) to only contain noonnegative values.
82+
/// enforced) to only contain non-negative values.
5183
///
5284
/// Examples:
5385
/// - shapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}.
54-
/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has higher
86+
/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has
87+
/// higher
5588
/// rank).
5689
/// - shapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is
5790
/// derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
@@ -60,14 +93,96 @@ SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
6093
std::optional<SmallVector<int64_t>>
6194
computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape);
6295

96+
//===----------------------------------------------------------------------===//
97+
// Utils that operate on AffineExpr.
98+
//===----------------------------------------------------------------------===//
99+
100+
/// Given a set of sizes, return the suffix product.
101+
///
102+
/// When applied to slicing, this is the calculation needed to derive the
103+
/// strides (i.e. the number of linear indices to skip along the (k-1) most
104+
/// minor dimensions to get the next k-slice).
105+
///
106+
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
107+
///
108+
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<AffineExpr>
109+
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
110+
///
111+
/// It is the caller's responsibility to pass proper AffineExpr kind that
112+
/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
113+
/// by an AffineDimExpr).
114+
///
115+
/// `sizes` elements are expected to bind to non-negative values.
116+
///
117+
/// Return an empty vector if `sizes` is empty.
118+
SmallVector<AffineExpr> computeSuffixProduct(ArrayRef<AffineExpr> sizes);
119+
inline SmallVector<AffineExpr> computeStrides(ArrayRef<AffineExpr> sizes) {
120+
return computeSuffixProduct(sizes);
121+
}
122+
123+
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
124+
///
125+
/// It is the caller's responsibility to pass proper AffineExpr kind that
126+
/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
127+
/// by an AffineDimExpr).
128+
///
129+
/// Return an empty vector if `v1` and `v2` are empty.
130+
SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
131+
ArrayRef<AffineExpr> v2);
132+
63133
/// Return the number of elements of basis (i.e. the max linear index).
64134
/// Return `0` if `basis` is empty.
65-
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
135+
///
136+
/// It is the caller's responsibility to pass proper AffineExpr kind that
137+
/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
138+
/// by an AffineDimExpr).
139+
///
140+
/// `basis` elements are expected to bind to non-negative values.
141+
///
142+
/// Return the `0` AffineConstantExpr if `basis` is empty.
143+
AffineExpr computeMaxLinearIndex(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
144+
145+
/// Return the linearized index of 'offsets' w.r.t. 'basis'.
146+
///
147+
/// Assuming `offsets` is `[o0, .. on]` and `basis` is `[b0, .. bn]`, return the
148+
/// AffineExpr `o0 * b0 + .. + on * bn`.
149+
///
150+
/// It is the caller's responsibility to pass proper AffineExpr kind that result
151+
/// in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an
152+
/// AffineDimExpr).
153+
///
154+
/// `basis` elements are expected to bind to non-negative values.
155+
AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
156+
ArrayRef<AffineExpr> basis);
157+
AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
158+
ArrayRef<int64_t> basis);
159+
160+
/// Given the strides together with a linear index in the dimension space,
161+
/// return the vector-space offsets in each dimension for a de-linearized index.
162+
///
163+
/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
164+
/// vector of AffineExpr
165+
/// `[li % s0, (li / s0) % s1, ..., (li / s0 / .. / sn-1) % sn]`
166+
///
167+
/// It is the caller's responsibility to pass proper AffineExpr kind that result
168+
/// in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an
169+
/// AffineDimExpr).
170+
///
171+
/// `strides` elements are expected to bind to non-negative values.
172+
SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
173+
ArrayRef<AffineExpr> strides);
174+
SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
175+
ArrayRef<int64_t> strides);
176+
177+
//===----------------------------------------------------------------------===//
178+
// Permutation utils.
179+
//===----------------------------------------------------------------------===//
66180

67181
/// Apply the permutation defined by `permutation` to `inVec`.
68182
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
69-
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
70-
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
183+
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
184+
/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a',
185+
/// 'b']`.
71186
template <typename T, unsigned N>
72187
void applyPermutationToVector(SmallVector<T, N> &inVec,
73188
ArrayRef<int64_t> permutation) {
@@ -83,18 +198,11 @@ SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
83198
/// Method to check if an interchange vector is a permutation.
84199
bool isPermutationVector(ArrayRef<int64_t> interchange);
85200

86-
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
201+
/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
202+
// TODO: Port everything relevant to DenseArrayAttr and drop this util.
87203
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
88204
unsigned dropBack = 0);
89205

90-
/// Computes and returns linearized affine expression w.r.t. `basis`.
91-
mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
92-
93-
/// Given the strides in the dimension space, returns the affine expressions for
94-
/// vector-space offsets in each dimension for a de-linearized index.
95-
SmallVector<mlir::AffineExpr>
96-
getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
97-
98206
} // namespace mlir
99207

100208
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,6 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
321321
bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
322322
}
323323

324-
template <typename AffineExprTy>
325-
void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl<AffineExprTy> &exprs) {
326-
int idx = 0;
327-
for (AffineExprTy &e : exprs)
328-
e = getAffineSymbolExpr(idx++, ctx);
329-
}
330-
331324
} // namespace detail
332325

333326
/// Bind a list of AffineExpr references to DimExpr at positions:
@@ -337,13 +330,27 @@ void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
337330
detail::bindDims<0>(ctx, exprs...);
338331
}
339332

333+
template <typename AffineExprTy>
334+
void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
335+
int idx = 0;
336+
for (AffineExprTy &e : exprs)
337+
e = getAffineDimExpr(idx++, ctx);
338+
}
339+
340340
/// Bind a list of AffineExpr references to SymbolExpr at positions:
341341
/// [0 .. sizeof...(exprs)]
342342
template <typename... AffineExprTy>
343343
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
344344
detail::bindSymbols<0>(ctx, exprs...);
345345
}
346346

347+
template <typename AffineExprTy>
348+
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
349+
int idx = 0;
350+
for (AffineExprTy &e : exprs)
351+
e = getAffineSymbolExpr(idx++, ctx);
352+
}
353+
347354
} // namespace mlir
348355

349356
namespace llvm {

mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
103103
loc, DenseElementsAttr::get(vecType, initValueAttr));
104104
SmallVector<int64_t> strides = computeStrides(shape);
105105
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
106-
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
106+
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
107107
SmallVector<Value> operands;
108108
for (Value input : op->getOperands())
109109
operands.push_back(

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
8989
vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
9090
SmallVector<int64_t> strides = computeStrides(shape);
9191
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
92-
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
92+
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
9393
SmallVector<Value> operands;
9494
for (auto input : op->getOperands())
9595
operands.push_back(

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
134134
SmallVector<Value> results(maxIndex);
135135

136136
for (int64_t i = 0; i < maxIndex; ++i) {
137-
auto offsets = delinearize(strides, i);
137+
auto offsets = delinearize(i, strides);
138138

139139
SmallVector<Value> extracted(expandedOperands.size());
140140
for (const auto &tuple : llvm::enumerate(expandedOperands))
@@ -152,7 +152,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
152152

153153
for (int64_t i = 0; i < maxIndex; ++i)
154154
result = builder.create<vector::InsertOp>(results[i], result,
155-
delinearize(strides, i));
155+
delinearize(i, strides));
156156

157157
// Reshape back to the original vector shape.
158158
return builder.create<vector::ShapeCastOp>(

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
7575
SmallVector<OpFoldResult> values(2 * sourceRank + 1);
7676
SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
7777

78-
detail::bindSymbolsList(rewriter.getContext(), symbols);
78+
bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
7979
AffineExpr expr = symbols.front();
8080
values[0] = ShapedType::isDynamic(sourceOffset)
8181
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
@@ -262,10 +262,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
262262
auto sourceType = source.getType().cast<MemRefType>();
263263
auto [strides, offset] = getStridesAndOffset(sourceType);
264264

265-
OpFoldResult origStride =
266-
ShapedType::isDynamic(strides[groupId])
267-
? origStrides[groupId]
268-
: builder.getIndexAttr(strides[groupId]);
265+
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
266+
? origStrides[groupId]
267+
: builder.getIndexAttr(strides[groupId]);
269268

270269
// Apply the original stride to all the strides.
271270
int64_t doneStrideIdx = 0;

0 commit comments

Comments
 (0)