Skip to content

[mlir][affine] Define affine.linearize_index #114480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1113,4 +1113,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// AffineLinearizeIndexOp
//===----------------------------------------------------------------------===//
def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
[Pure, AttrSizedOperandSegments]> {
let summary = "linearize an index";
let description = [{
The `affine.linearize_index` operation takes a sequence of index values and a
basis of the same length and linearizes the indices using that basis.

That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
it computes

```
sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
```

If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.

Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.

Example:

```mlir
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
```

In the above example, `%linear_index` conceptually holds the following:

```mlir
#map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
%linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
```
}];

let arguments = (ins Variadic<Index>:$multi_index,
Variadic<Index>:$dynamic_basis,
DenseI64ArrayAttr:$static_basis,
UnitProperty:$disjoint);
let results = (outs Index:$linear_index);

let assemblyFormat = [{
(`disjoint` $disjoint^)? ` `
`[` $multi_index `]` `by` ` `
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
attr-dict `:` type($linear_index)
}];

let builders = [
OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "false">:$disjoint)>,
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "false">:$disjoint)>
];

let extraClassDeclaration = [{
/// Return a vector with all the static and dynamic basis values.
SmallVector<OpFoldResult> getMixedBasis() {
OpBuilder builder(getContext());
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

}];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

#endif // AFFINE_OPS
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,17 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<OpFoldResult> basis);

// Generate IR that extracts the linear index from a multi-index according to
// a basis/shape.
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder);

OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis);

/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
/// between the operations) do not have the potential memory effect
Expand Down
109 changes: 109 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4684,6 +4684,115 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
}

//===----------------------------------------------------------------------===//
// LinearizeIndexOp
//===----------------------------------------------------------------------===//

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
staticBasis);
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex,
ArrayRef<int64_t> basis, bool disjoint) {
build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}

LogicalResult AffineLinearizeIndexOp::verify() {
if (getStaticBasis().empty())
return emitOpError("basis should not be empty");

if (getMultiIndex().size() != getStaticBasis().size())
return emitOpError("should be passed an index for each basis element");

auto dynamicMarkersCount =
llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
return emitOpError(
"mismatch between dynamic and static basis (kDynamic marker but no "
"corresponding dynamic basis entry) -- this can only happen due to an "
"incorrect fold/rewrite");

return success();
}

namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
/// %...d)`.

/// Note that `disjoint` is required here, because, without it, we could have
/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
/// is a valid operation where the `%c64` cannot be trivially dropped.
///
/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
/// the operation isn't asserted to be `disjoint`.
struct DropLinearizeUnitComponentsIfDisjointOrZero final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
size_t numIndices = op.getMultiIndex().size();
SmallVector<Value> newIndices;
newIndices.reserve(numIndices);
SmallVector<OpFoldResult> newBasis;
newBasis.reserve(numIndices);

SmallVector<OpFoldResult> basis = op.getMixedBasis();
for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
if (!basisEntry || *basisEntry != 1) {
newIndices.push_back(index);
newBasis.push_back(basisElem);
continue;
}

std::optional<int64_t> indexValue = getConstantIntValue(index);
if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
newIndices.push_back(index);
newBasis.push_back(basisElem);
continue;
}
}
if (newIndices.size() == numIndices)
return failure();

if (newIndices.size() == 0) {
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
return success();
}
rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
op, newIndices, newBasis, op.getDisjoint());
return success();
}
};
} // namespace

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 20 additions & 1 deletion mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
Expand Down Expand Up @@ -44,6 +45,23 @@ struct LowerDelinearizeIndexOps
}
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> multiIndex =
getAsOpFoldResult(op.getMultiIndex());
OpFoldResult linearIndex =
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
Value linearIndexValue =
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
rewriter.replaceOp(op, linearIndexValue);
return success();
}
};

class ExpandAffineIndexOpsPass
: public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
public:
Expand All @@ -63,7 +81,8 @@ class ExpandAffineIndexOpsPass

void mlir::affine::populateAffineExpandIndexOpsPatterns(
RewritePatternSet &patterns) {
patterns.insert<LowerDelinearizeIndexOps>(patterns.getContext());
patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
patterns.getContext());
}

std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
Expand Down
14 changes: 10 additions & 4 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder) {
return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
}

OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis) {
assert(multiIndex.size() == basis.size());
SmallVector<AffineExpr> basisAffine;
for (size_t i = 0; i < basis.size(); ++i) {
Expand All @@ -2009,13 +2015,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
SmallVector<OpFoldResult> strides;
strides.reserve(stridesAffine.size());
llvm::transform(stridesAffine, std::back_inserter(strides),
[&builder, &basis](AffineExpr strideExpr) {
[&builder, &basis, loc](AffineExpr strideExpr) {
return affine::makeComposedFoldedAffineApply(
builder, builder.getLoc(), strideExpr, basis);
builder, loc, strideExpr, basis);
});

auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
return affine::makeComposedFoldedAffineApply(
builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
multiIndexAndStrides);
}
17 changes: 17 additions & 0 deletions mlir/test/Conversion/AffineToStandard/lower-affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -976,3 +976,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index
// CHECK: return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index
// CHECK: }

/////////////////////////////////////////////////////////////////////

func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
return %ret : index
}

// CHECK-LABEL: @test_linearize_index
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK: %[[c15:.+]] = arith.constant 15 : index
// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
// CHECK-NEXT: return %[[ret]]
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
%1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}

// -----

// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>

// CHECK-LABEL: @linearize_static
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
// CHECK: return %[[val_0]]
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
func.return %0 : index
}

// -----

// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>

// CHECK-LABEL: @linearize_dynamic
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
// CHECK: return %[[val_0]]
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just curious if we even need %arg3. Would it be used if disjoint is false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need it - but we also don't need the first basis element on delinearize_index either.

This is kept for symmetry with memref.load and its friends

func.return %0 : index
}
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1530,3 +1530,37 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
%2 = affine.delinearize_index %i into (1024) : index
return %2 : index
}

// -----

// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
// CHECK: return %[[ret]]
func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index
return %ret : index
}

// -----

// CHECK-LABEL: @linearize_unit_basis_zero
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
// CHECK: return %[[ret]]
func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index {
%c0 = arith.constant 0 : index
%ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index
return %ret : index
}

// -----

// CHECK-LABEL: @linearize_all_zero_unit_basis
// CHECK: arith.constant 0 : index
// CHECK-NOT: affine.linearize_index
func.func @linearize_all_zero_unit_basis() -> index {
%c0 = arith.constant 0 : index
%ret = affine.linearize_index [%c0, %c0] by (1, 1) : index
return %ret : index
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Affine/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,22 @@ func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {

// -----

func.func @linearize(%idx: index, %basis0: index, %basis1 :index) -> index {
// expected-error@+1 {{'affine.linearize_index' op should be passed an index for each basis element}}
%0 = affine.linearize_index [%idx] by (%basis0, %basis1) : index
return %0 : index
}

// -----

func.func @linearize_empty() -> index {
// expected-error@+1 {{'affine.linearize_index' op basis should not be empty}}
%0 = affine.linearize_index [] by () : index
return %0 : index
}

// -----

func.func @dynamic_dimension_index() {
"unknown.region"() ({
%idx = "unknown.test"() : () -> (index)
Expand Down
Loading
Loading