Skip to content

[mlir][affine] Define affine.linearize_index #114480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 5, 2024

Conversation

krzysz00
Copy link
Contributor

affine.linearize_index is the inverse of affine.delinearize_index and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices.

This commit introduces affine.linearize_index and one simple canonicalization for it.

There are plans to add affine.linearize_index and affine.delinearize_index pair canonicalizations, but we are saving those for a followup PR (especially since having #113846 landed would make them nicer).

Note while affine may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location.

`affine.linearize_index` is the inverse of `affine.delinearize_index`
and general useful for representing computations (like those needed to
move from N-D to 1-D memrefs) that put together indices.

This commit introduces `affine.linearize_index` and one simple
canonicalization for it.

There are plans to add `affine.linearize_index` and
`affine.delinearize_index` pair canonicalizations, but we are saving
those for a followup PR (especially since having llvm#113846 landed would
make them nicer).

Note while `affine` may not be the natural home for this operation,
https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13
didn't come to any better consensus location.
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

affine.linearize_index is the inverse of affine.delinearize_index and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices.

This commit introduces affine.linearize_index and one simple canonicalization for it.

There are plans to add affine.linearize_index and affine.delinearize_index pair canonicalizations, but we are saving those for a followup PR (especially since having #113846 landed would make them nicer).

Note while affine may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location.


Patch is 21.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114480.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+69)
  • (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+3)
  • (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+16)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+106)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+21-1)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+10-4)
  • (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+17)
  • (modified) mlir/test/Dialect/Affine/affine-expand-index-ops.mlir (+26)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+34)
  • (modified) mlir/test/Dialect/Affine/invalid.mlir (+16)
  • (modified) mlir/test/Dialect/Affine/ops.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..40ce26d6527cc0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..c49aa8281f49d1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1099,4 +1099,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AffineLinearizeIndexOp
+//===----------------------------------------------------------------------===//
+def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
+    [Pure, AttrSizedOperandSegments]> {
+  let summary = "linearize an index";
+  let description = [{
+    The `affine.linearize_index` operation takes a sequence of index values and a
+    basis of the same length and linearizes the indices using that basis.
+
+    That is, for indices %idx_1 through %i_N and basis elements b_1 through b_N,
+    it computes
+
+    ```
+    sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
+    ```
+
+    If the `disjoint` property is present, this is an optimization hint that,
+    for all i, 0 <= %idx_i < B_i - that is, no index affects any other index,
+    except that %idx_0 may be negative to make the index as a whole negative.
+
+    Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
+
+    Example:
+
+    ```
+    %linear_index = affine.delinearize_index [%index_0, %index_1, %index_2] (16, 224, 224) : index
+    ```
+
+    In the above example, `%linear_index` conceptually holds the following:
+
+    ```
+    #map = affine_map<()[s0, s1, s2] -> (s0 * 50176 + s1 * 224 + s2)>
+    %linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$multi_index,
+    Variadic<Index>:$dynamic_basis,
+    DenseI64ArrayAttr:$static_basis,
+    UnitProperty:$disjoint);
+  let results = (outs Index:$linear_index);
+
+  let assemblyFormat = [{
+    (`disjoint` $disjoint^)? ` `
+    `[` $multi_index `]` `by` ` `
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($linear_index)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
+    OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "false">:$disjoint)>,
+    OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "false">:$disjoint)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector with all the static and dynamic basis values.
+    SmallVector<OpFoldResult> getMixedBasis() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
+
+  }];
+
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
+}
+
 #endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..f518272ec0e8f9 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -316,6 +316,9 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                             ArrayRef<OpFoldResult> basis,
                             ImplicitLocOpBuilder &builder);
+OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
+                            ArrayRef<OpFoldResult> multiIndex,
+                            ArrayRef<OpFoldResult> basis);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                  OperandRange values,
+                                  ArrayRef<int64_t> integers,
+                                  AsmParser::Delimiter delimiter) {
+  return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
+                               delimiter);
+}
 inline void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      DenseI64ArrayAttr &integers,
+                      AsmParser::Delimiter delimiter) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
+                               delimiter);
+}
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..2d3b5a80df4b0e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4664,6 +4664,112 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
   patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// LinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                   OperationState &odsState,
+                                   ValueRange multiIndex, ValueRange basis,
+                                   bool disjoint) {
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
+                             staticBasis);
+  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
+}
+
+void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                   OperationState &odsState,
+                                   ValueRange multiIndex,
+                                   ArrayRef<OpFoldResult> basis,
+                                   bool disjoint) {
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
+}
+
+void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                   OperationState &odsState,
+                                   ValueRange multiIndex,
+                                   ArrayRef<int64_t> basis, bool disjoint) {
+  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
+}
+
+LogicalResult AffineLinearizeIndexOp::verify() {
+  if (getStaticBasis().empty())
+    return emitOpError("basis should not be empty");
+  if (getMultiIndex().size() != getStaticBasis().size())
+    return emitOpError("should be passed an index for each basis element");
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+    return emitOpError(
+        "mismatch between dynamic and static basis (kDynamic marker but no "
+        "corresponding dynamic basis entry) -- this can only happen due to an "
+        "incorrect fold/rewrite");
+  return success();
+}
+
+namespace {
+/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
+/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
+/// %...d)`.
+
+/// Note that `disjoint` is required here, because, without it, we could have
+/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
+/// is a valid operation where the `%c64` cannot be trivially dropped.
+///
+/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
+/// the operation isn't asserted to be `disjoint`.
+struct DropLinearizeUnitComponentsIfDisjointOrZero
+    : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    size_t numIndices = op.getMultiIndex().size();
+    SmallVector<Value> newIndices;
+    newIndices.reserve(numIndices);
+    SmallVector<OpFoldResult> newBasis;
+    newBasis.reserve(numIndices);
+
+    SmallVector<OpFoldResult> basis = op.getMixedBasis();
+    for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
+      std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
+      if (!basisEntry || *basisEntry != 1) {
+        newIndices.push_back(index);
+        newBasis.push_back(basisElem);
+        continue;
+      }
+
+      std::optional<int64_t> indexValue = getConstantIntValue(index);
+      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
+        newIndices.push_back(index);
+        newBasis.push_back(basisElem);
+        continue;
+      }
+    }
+    if (newIndices.size() == numIndices)
+      return failure();
+
+    if (newIndices.size() == 0) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, newIndices, newBasis, op.getDisjoint());
+    return success();
+  }
+};
+} // namespace
+
+void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..c96c188cbc89fc 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -45,6 +46,24 @@ struct LowerDelinearizeIndexOps
   }
 };
 
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions.
+struct LowerLinearizeIndexOps
+    : public OpRewritePattern<AffineLinearizeIndexOp> {
+  using OpRewritePattern<AffineLinearizeIndexOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> multiIndex =
+        getAsOpFoldResult(op.getMultiIndex());
+    OpFoldResult linearIndex =
+        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
+    Value linearIndexValue =
+        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
+    rewriter.replaceOp(op, linearIndexValue);
+    return success();
+  }
+};
+
 class ExpandAffineIndexOpsPass
     : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
 public:
@@ -64,7 +83,8 @@ class ExpandAffineIndexOpsPass
 
 void mlir::affine::populateAffineExpandIndexOpsPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<LowerDelinearizeIndexOps>(patterns.getContext());
+  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
+      patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..76b67555feb132 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1973,6 +1973,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
 OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis,
                                           ImplicitLocOpBuilder &builder) {
+  return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
+}
+
+OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
+                                          ArrayRef<OpFoldResult> multiIndex,
+                                          ArrayRef<OpFoldResult> basis) {
   assert(multiIndex.size() == basis.size());
   SmallVector<AffineExpr> basisAffine;
   for (size_t i = 0; i < basis.size(); ++i) {
@@ -1983,13 +1989,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
   SmallVector<OpFoldResult> strides;
   strides.reserve(stridesAffine.size());
   llvm::transform(stridesAffine, std::back_inserter(strides),
-                  [&builder, &basis](AffineExpr strideExpr) {
+                  [&builder, &basis, loc](AffineExpr strideExpr) {
                     return affine::makeComposedFoldedAffineApply(
-                        builder, builder.getLoc(), strideExpr, basis);
+                        builder, loc, strideExpr, basis);
                   });
 
   auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
       OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
-  return affine::makeComposedFoldedAffineApply(
-      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+  return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
+                                               multiIndexAndStrides);
 }
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..7e1c25355c3086 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -981,3 +981,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
 // CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
 // CHECK:           return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
 // CHECK:         }
+
+/////////////////////////////////////////////////////////////////////
+
+func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
+  return %ret : index
+}
+
+// CHECK-LABEL: @test_linearize_index
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[c15:.+]] = arith.constant 15 : index
+// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
+// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
+// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
+// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
+// CHECK-NEXT: return %[[ret]]
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index 70b7f397ad4fec..fb20ba2ee69e01 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -44,3 +44,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
   %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
+
+// CHECK-LABEL: @linearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
+  func.return %0 : index
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] =  affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
+
+// CHECK-LABEL: @linearize_dynamic
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
+  func.return %0 : index
+}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 906ae81c76d115..5424e082902609 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1533,3 +1533,37 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
   %2 = affine.delinearize_index %i into (%c1024) : index
   return %2 : index
 }
+
+// -----
+
+// CHECK-LABEL: @linearize_unit_basis_disjoint
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
+// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
+// CHECK: return %[[ret]]
+func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
+  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_unit_basis_zero
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
+// CHECK: return %[[ret]]
+func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_all_zero_unit_basis
+// CHECK: arith.constant 0 : index
+// CHECK-NOT: affine.linearize_index
+func.func @linearize_all_zero_unit_basis() -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %c0] by (1, 1) : index
+  return %ret : index
+}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 869ea712bb3690..2996194170900f 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -548,6 +548,22 @@ func.func @delinearize(%idx: index, %basis0: index, %b...
[truncated]

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation looks good, just some cosmetic issues

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looks good. I just have one clarifying question.

// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
// CHECK: return %[[val_0]]
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just curious if we even need %arg3. Would it be used if disjoint is false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need it - but we also don't need the first basis element on delinearize_index either.

This is kept for symmetry with memref.load and its friends

Copy link

github-actions bot commented Nov 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@krzysz00 krzysz00 merged commit 847d507 into llvm:main Nov 5, 2024
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants