Skip to content

Commit 847d507

Browse files
krzysz00kuhar
andauthored
[mlir][affine] Define affine.linearize_index (#114480)
`affine.linearize_index` is the inverse of `affine.delinearize_index` and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices. This commit introduces `affine.linearize_index` and one simple canonicalization for it. There are plans to add `affine.linearize_index` and `affine.delinearize_index` pair canonicalizations, but we are saving those for a followup PR (especially since having #113846 landed would make them nicer). Note while `affine` may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent cb9700e commit 847d507

File tree

10 files changed

+322
-5
lines changed

10 files changed

+322
-5
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,4 +1113,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
11131113
let hasCanonicalizer = 1;
11141114
}
11151115

1116+
//===----------------------------------------------------------------------===//
1117+
// AffineLinearizeIndexOp
1118+
//===----------------------------------------------------------------------===//
1119+
def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
1120+
[Pure, AttrSizedOperandSegments]> {
1121+
let summary = "linearize an index";
1122+
let description = [{
1123+
The `affine.linearize_index` operation takes a sequence of index values and a
1124+
basis of the same length and linearizes the indices using that basis.
1125+
1126+
That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
1127+
it computes
1128+
1129+
```
1130+
sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
1131+
```
1132+
1133+
If the `disjoint` property is present, this is an optimization hint that,
1134+
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
1135+
except that `%idx_0` may be negative to make the index as a whole negative.
1136+
1137+
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
1138+
1139+
Example:
1140+
1141+
```mlir
1142+
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
1143+
```
1144+
1145+
In the above example, `%linear_index` conceptually holds the following:
1146+
1147+
```mlir
1148+
#map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
1149+
%linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
1150+
```
1151+
}];
1152+
1153+
let arguments = (ins Variadic<Index>:$multi_index,
1154+
Variadic<Index>:$dynamic_basis,
1155+
DenseI64ArrayAttr:$static_basis,
1156+
UnitProperty:$disjoint);
1157+
let results = (outs Index:$linear_index);
1158+
1159+
let assemblyFormat = [{
1160+
(`disjoint` $disjoint^)? ` `
1161+
`[` $multi_index `]` `by` ` `
1162+
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
1163+
attr-dict `:` type($linear_index)
1164+
}];
1165+
1166+
let builders = [
1167+
OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
1168+
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "false">:$disjoint)>,
1169+
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "false">:$disjoint)>
1170+
];
1171+
1172+
let extraClassDeclaration = [{
1173+
/// Return a vector with all the static and dynamic basis values.
1174+
SmallVector<OpFoldResult> getMixedBasis() {
1175+
OpBuilder builder(getContext());
1176+
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
1177+
}
1178+
1179+
}];
1180+
1181+
let hasVerifier = 1;
1182+
let hasCanonicalizer = 1;
1183+
}
1184+
11161185
#endif // AFFINE_OPS

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,17 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
315315
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
316316
Value linearIndex,
317317
ArrayRef<OpFoldResult> basis);
318+
318319
// Generate IR that extracts the linear index from a multi-index according to
319320
// a basis/shape.
320321
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
321322
ArrayRef<OpFoldResult> basis,
322323
ImplicitLocOpBuilder &builder);
323324

325+
OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
326+
ArrayRef<OpFoldResult> multiIndex,
327+
ArrayRef<OpFoldResult> basis);
328+
324329
/// Ensure that all operations that could be executed after `start`
325330
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
326331
/// between the operations) do not have the potential memory effect

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4684,6 +4684,115 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
46844684
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
46854685
}
46864686

4687+
//===----------------------------------------------------------------------===//
4688+
// LinearizeIndexOp
4689+
//===----------------------------------------------------------------------===//
4690+
4691+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4692+
OperationState &odsState,
4693+
ValueRange multiIndex, ValueRange basis,
4694+
bool disjoint) {
4695+
SmallVector<Value> dynamicBasis;
4696+
SmallVector<int64_t> staticBasis;
4697+
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4698+
staticBasis);
4699+
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4700+
}
4701+
4702+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4703+
OperationState &odsState,
4704+
ValueRange multiIndex,
4705+
ArrayRef<OpFoldResult> basis,
4706+
bool disjoint) {
4707+
SmallVector<Value> dynamicBasis;
4708+
SmallVector<int64_t> staticBasis;
4709+
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4710+
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4711+
}
4712+
4713+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4714+
OperationState &odsState,
4715+
ValueRange multiIndex,
4716+
ArrayRef<int64_t> basis, bool disjoint) {
4717+
build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4718+
}
4719+
4720+
LogicalResult AffineLinearizeIndexOp::verify() {
4721+
if (getStaticBasis().empty())
4722+
return emitOpError("basis should not be empty");
4723+
4724+
if (getMultiIndex().size() != getStaticBasis().size())
4725+
return emitOpError("should be passed an index for each basis element");
4726+
4727+
auto dynamicMarkersCount =
4728+
llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4729+
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4730+
return emitOpError(
4731+
"mismatch between dynamic and static basis (kDynamic marker but no "
4732+
"corresponding dynamic basis entry) -- this can only happen due to an "
4733+
"incorrect fold/rewrite");
4734+
4735+
return success();
4736+
}
4737+
4738+
namespace {
4739+
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
4740+
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
4741+
/// %...d)`.
4742+
4743+
/// Note that `disjoint` is required here, because, without it, we could have
4744+
/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
4745+
/// is a valid operation where the `%c64` cannot be trivially dropped.
4746+
///
4747+
/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
4748+
/// the operation isn't asserted to be `disjoint`.
4749+
struct DropLinearizeUnitComponentsIfDisjointOrZero final
4750+
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
4751+
using OpRewritePattern::OpRewritePattern;
4752+
4753+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
4754+
PatternRewriter &rewriter) const override {
4755+
size_t numIndices = op.getMultiIndex().size();
4756+
SmallVector<Value> newIndices;
4757+
newIndices.reserve(numIndices);
4758+
SmallVector<OpFoldResult> newBasis;
4759+
newBasis.reserve(numIndices);
4760+
4761+
SmallVector<OpFoldResult> basis = op.getMixedBasis();
4762+
for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
4763+
std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
4764+
if (!basisEntry || *basisEntry != 1) {
4765+
newIndices.push_back(index);
4766+
newBasis.push_back(basisElem);
4767+
continue;
4768+
}
4769+
4770+
std::optional<int64_t> indexValue = getConstantIntValue(index);
4771+
if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
4772+
newIndices.push_back(index);
4773+
newBasis.push_back(basisElem);
4774+
continue;
4775+
}
4776+
}
4777+
if (newIndices.size() == numIndices)
4778+
return failure();
4779+
4780+
if (newIndices.size() == 0) {
4781+
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
4782+
return success();
4783+
}
4784+
rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
4785+
op, newIndices, newBasis, op.getDisjoint());
4786+
return success();
4787+
}
4788+
};
4789+
} // namespace
4790+
4791+
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
4792+
RewritePatternSet &patterns, MLIRContext *context) {
4793+
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
4794+
}
4795+
46874796
//===----------------------------------------------------------------------===//
46884797
// TableGen'd op method definitions
46894798
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1616
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1717
#include "mlir/Dialect/Affine/Utils.h"
18+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1920

2021
namespace mlir {
@@ -44,6 +45,23 @@ struct LowerDelinearizeIndexOps
4445
}
4546
};
4647

48+
/// Lowers `affine.linearize_index` into a sequence of multiplications and
49+
/// additions.
50+
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
51+
using OpRewritePattern::OpRewritePattern;
52+
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
53+
PatternRewriter &rewriter) const override {
54+
SmallVector<OpFoldResult> multiIndex =
55+
getAsOpFoldResult(op.getMultiIndex());
56+
OpFoldResult linearIndex =
57+
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
58+
Value linearIndexValue =
59+
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
60+
rewriter.replaceOp(op, linearIndexValue);
61+
return success();
62+
}
63+
};
64+
4765
class ExpandAffineIndexOpsPass
4866
: public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
4967
public:
@@ -63,7 +81,8 @@ class ExpandAffineIndexOpsPass
6381

6482
void mlir::affine::populateAffineExpandIndexOpsPatterns(
6583
RewritePatternSet &patterns) {
66-
patterns.insert<LowerDelinearizeIndexOps>(patterns.getContext());
84+
patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
85+
patterns.getContext());
6786
}
6887

6988
std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
19991999
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
20002000
ArrayRef<OpFoldResult> basis,
20012001
ImplicitLocOpBuilder &builder) {
2002+
return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
2003+
}
2004+
2005+
OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
2006+
ArrayRef<OpFoldResult> multiIndex,
2007+
ArrayRef<OpFoldResult> basis) {
20022008
assert(multiIndex.size() == basis.size());
20032009
SmallVector<AffineExpr> basisAffine;
20042010
for (size_t i = 0; i < basis.size(); ++i) {
@@ -2009,13 +2015,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
20092015
SmallVector<OpFoldResult> strides;
20102016
strides.reserve(stridesAffine.size());
20112017
llvm::transform(stridesAffine, std::back_inserter(strides),
2012-
[&builder, &basis](AffineExpr strideExpr) {
2018+
[&builder, &basis, loc](AffineExpr strideExpr) {
20132019
return affine::makeComposedFoldedAffineApply(
2014-
builder, builder.getLoc(), strideExpr, basis);
2020+
builder, loc, strideExpr, basis);
20152021
});
20162022

20172023
auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
20182024
OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
2019-
return affine::makeComposedFoldedAffineApply(
2020-
builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
2025+
return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
2026+
multiIndexAndStrides);
20212027
}

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,3 +976,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
976976
// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index
977977
// CHECK: return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index
978978
// CHECK: }
979+
980+
/////////////////////////////////////////////////////////////////////
981+
982+
func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
983+
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
984+
return %ret : index
985+
}
986+
987+
// CHECK-LABEL: @test_linearize_index
988+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
989+
// CHECK: %[[c15:.+]] = arith.constant 15 : index
990+
// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
991+
// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
992+
// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
993+
// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
994+
// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
995+
// CHECK-NEXT: return %[[ret]]

mlir/test/Dialect/Affine/affine-expand-index-ops.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
4141
%1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
4242
return %1#0, %1#1, %1#2 : index, index, index
4343
}
44+
45+
// -----
46+
47+
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
48+
49+
// CHECK-LABEL: @linearize_static
50+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
51+
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
52+
// CHECK: return %[[val_0]]
53+
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
54+
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
55+
func.return %0 : index
56+
}
57+
58+
// -----
59+
60+
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
61+
62+
// CHECK-LABEL: @linearize_dynamic
63+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
64+
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
65+
// CHECK: return %[[val_0]]
66+
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
67+
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
68+
func.return %0 : index
69+
}

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,3 +1530,37 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
15301530
%2 = affine.delinearize_index %i into (1024) : index
15311531
return %2 : index
15321532
}
1533+
1534+
// -----
1535+
1536+
// CHECK-LABEL: @linearize_unit_basis_disjoint
1537+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
1538+
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
1539+
// CHECK: return %[[ret]]
1540+
func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
1541+
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index
1542+
return %ret : index
1543+
}
1544+
1545+
// -----
1546+
1547+
// CHECK-LABEL: @linearize_unit_basis_zero
1548+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
1549+
// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
1550+
// CHECK: return %[[ret]]
1551+
func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index {
1552+
%c0 = arith.constant 0 : index
1553+
%ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index
1554+
return %ret : index
1555+
}
1556+
1557+
// -----
1558+
1559+
// CHECK-LABEL: @linearize_all_zero_unit_basis
1560+
// CHECK: arith.constant 0 : index
1561+
// CHECK-NOT: affine.linearize_index
1562+
func.func @linearize_all_zero_unit_basis() -> index {
1563+
%c0 = arith.constant 0 : index
1564+
%ret = affine.linearize_index [%c0, %c0] by (1, 1) : index
1565+
return %ret : index
1566+
}

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,22 @@ func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
548548

549549
// -----
550550

551+
func.func @linearize(%idx: index, %basis0: index, %basis1 :index) -> index {
552+
// expected-error@+1 {{'affine.linearize_index' op should be passed an index for each basis element}}
553+
%0 = affine.linearize_index [%idx] by (%basis0, %basis1) : index
554+
return %0 : index
555+
}
556+
557+
// -----
558+
559+
func.func @linearize_empty() -> index {
560+
// expected-error@+1 {{'affine.linearize_index' op basis should not be empty}}
561+
%0 = affine.linearize_index [] by () : index
562+
return %0 : index
563+
}
564+
565+
// -----
566+
551567
func.func @dynamic_dimension_index() {
552568
"unknown.region"() ({
553569
%idx = "unknown.test"() : () -> (index)

0 commit comments

Comments
 (0)