Skip to content

Commit 31aa7f3

Browse files
authored
[mlir][Affine] Let affine.[de]linearize_index omit outer bounds (#116103)
The affine.delinearize_index and affine.linearize_index operations, as currently defined, require providing a length N basis to [de]linearize N values. The first value in this basis is never used during lowering and is unused during lowering. (Note that, even though it isn't used during lowering it can still be used to, for example, remove length-1 outputs from a delinearize). This dead value makes sense in the original context of these operations, which is linearizing or de-linearizing indexes to memref<>s, vector<>s, and other shaped types, where that outer bound is avaliable and may be useful for analysis. However, other usecases exist where the outer bound is not known. For example: %thread_id_x = gpu.thread_id x : index %0:3 = affine.delinearize_index %thread_id_x into (4, 16) : index,index, index In this code, we don't know the upper bound of the thread ID, but we do want to construct the ?x4x16 grid of delinearized values in order to further partition the GPU threads. In order to support such usecases, we broaden the definition of affine.delinearize_index and affine.linearize_index to make the outer bound optional. In the case of affine.delinearize_index, where the number of results is a function of the size of the passed-in basis, we augment all existing builders with a `hasOuterBound` argument, which, for backwards compatibilty and to preserve the natural usage of the op, defaults to `true`. If this flag is true, the op returns one result per basis element, if it is false, it returns one extra result in position 0. We also update existing canonicalization patterns (and move one of them into the folder) to handle these cases. Note that disagreements about the outer bound now no longer prevent delinearize/linearize cancelations.
1 parent f8d1905 commit 31aa7f3

File tree

9 files changed

+402
-127
lines changed

9 files changed

+402
-127
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,8 +1060,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
10601060
// AffineDelinearizeIndexOp
10611061
//===----------------------------------------------------------------------===//
10621062

1063-
def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
1064-
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1063+
def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
10651064
let summary = "delinearize an index";
10661065
let description = [{
10671066
The `affine.delinearize_index` operation takes a single index value and
@@ -1083,6 +1082,25 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
10831082
%indices_1 = affine.apply #map1()[%linear_index]
10841083
%indices_2 = affine.apply #map2()[%linear_index]
10851084
```
1085+
1086+
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
1087+
If there are N basis elements, the first one will not be used during computations,
1088+
but may be used during analysis and canonicalization to eliminate terms from
1089+
the `affine.delinearize_index` or to enable conclusions about the total size of
1090+
`%linear_index`.
1091+
1092+
If the basis is fully provided, the delinearize_index operation is said to "have
1093+
an outer bound". The builders assume that an `affine.delinearize_index` has
1094+
an outer bound by default, as this is how the operation was initially defined.
1095+
1096+
That is, the example above could also have been written
1097+
```mlir
1098+
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
1099+
```
1100+
1101+
Note that, due to the constraints of affine maps, all the basis elements must
1102+
be strictly positive. A dynamic basis element being 0 or negative causes
1103+
undefined behavior.
10861104
}];
10871105

10881106
let arguments = (ins Index:$linear_index,
@@ -1097,17 +1115,27 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
10971115
}];
10981116

10991117
let builders = [
1100-
OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
1101-
OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
1102-
OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
1118+
OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_asis, CArg<"bool", "true">:$hasOuterBound)>,
1119+
OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis, CArg<"bool", "true">:$hasOuterBound)>,
1120+
OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "true">:$hasOuterBound)>,
1121+
OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "true">:$hasOuterBound)>
11031122
];
11041123

11051124
let extraClassDeclaration = [{
1125+
/// Return true if the basis includes a bound on the first index input.
1126+
bool hasOuterBound() {
1127+
return getMultiIndex().size() == getStaticBasis().size();
1128+
}
1129+
11061130
/// Returns a vector with all the static and dynamic basis values.
11071131
SmallVector<OpFoldResult> getMixedBasis() {
11081132
OpBuilder builder(getContext());
11091133
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
11101134
}
1135+
1136+
/// Return a vector that contains the basis of the operation, removing
1137+
/// the outer bound if one is present.
1138+
SmallVector<OpFoldResult> getEffectiveBasis();
11111139
}];
11121140

11131141
let hasVerifier = 1;
@@ -1125,13 +1153,21 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
11251153
The `affine.linearize_index` operation takes a sequence of index values and a
11261154
basis of the same length and linearizes the indices using that basis.
11271155

1128-
That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
1129-
it computes
1156+
That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `b_0`
1157+
(or `b_1`) up to `b_{N-1}` it computes
11301158

11311159
```
1132-
sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
1160+
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
11331161
```
11341162

1163+
The basis may either have `N` or `N-1` elements, where `N` is the number of
1164+
inputs to linearize_index. If `N` inputs are provided, the first one is not used
1165+
in computation, but may be used during analysis or canonicalization as a bound
1166+
on `%idx_0`.
1167+
1168+
If all `N` basis elements are provided, the linearize_index operation is said to
1169+
"have an outer bound".
1170+
11351171
If the `disjoint` property is present, this is an optimization hint that,
11361172
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
11371173
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1141,7 +1177,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
11411177
Example:
11421178

11431179
```mlir
1144-
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
1180+
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index
1181+
// Same effect
1182+
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (3, 5) : index
11451183
```
11461184

11471185
In the above example, `%linear_index` conceptually holds the following:
@@ -1172,12 +1210,20 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
11721210
];
11731211

11741212
let extraClassDeclaration = [{
1213+
/// Return true if the basis includes a bound on the first index input.
1214+
bool hasOuterBound() {
1215+
return getMultiIndex().size() == getStaticBasis().size();
1216+
}
1217+
11751218
/// Return a vector with all the static and dynamic basis values.
11761219
SmallVector<OpFoldResult> getMixedBasis() {
11771220
OpBuilder builder(getContext());
11781221
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
11791222
}
11801223

1224+
/// Return a vector that contains the basis of the operation, removing
1225+
/// the outer bound if one is present.
1226+
SmallVector<OpFoldResult> getEffectiveBasis();
11811227
}];
11821228

11831229
let hasVerifier = 1;

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,17 +307,23 @@ struct DivModValue {
307307
DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
308308

309309
/// Generate the IR to delinearize `linearIndex` given the `basis` and return
310-
/// the multi-index.
310+
/// the multi-index. `hasOuterBound` indicates whether `basis` has an entry
311+
/// given the size of the first multi-index result - if it is true, the function
312+
/// will return `basis.size()` values, otherwise, it will return `basis.size() +
313+
/// 1`.
311314
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
312315
Value linearIndex,
313-
ArrayRef<Value> basis);
316+
ArrayRef<Value> basis,
317+
bool hasOuterBound = true);
314318

315319
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
316320
Value linearIndex,
317-
ArrayRef<OpFoldResult> basis);
321+
ArrayRef<OpFoldResult> basis,
322+
bool hasOuterBound = true);
318323

319324
// Generate IR that extracts the linear index from a multi-index according to
320-
// a basis/shape.
325+
// a basis/shape. The basis may contain either `multiIndex.size()` or
326+
// `multiIndex.size() - 1` elements.
321327
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
322328
ArrayRef<OpFoldResult> basis,
323329
ImplicitLocOpBuilder &builder);

0 commit comments

Comments
 (0)