[mlir][sparse] Introduce batch level format. #83082

Merged
merged 4 commits into llvm:main on Feb 27, 2024

Conversation

PeimingLiu
Member

No description provided.

@llvmbot
Member

llvmbot commented Feb 26, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/83082.diff

11 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+5-4)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+23-5)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+2)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+2)
  • (modified) mlir/test/CAPI/sparse_tensor.c (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/invalid_encoding.mlir (+6)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir (+11)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+1-1)
  • (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+4-4)
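
For reference, a minimal standalone sketch (not part of the patch) restating the renumbered LevelFormat codes with their decimal equivalents; the decimal values are what drive the FileCheck constant updates in the tests below (e.g. 131072 -> 262144 for compressed levels):

#include <cstdint>

// Sketch: the level-format codes after this patch. Every non-Undef format
// occupies a single bit above the 16 property bits; inserting Batch shifts
// each later format up by one bit position.
enum class LevelFormat : uint64_t {
  Undef           = 0x00000000,
  Dense           = 0x00010000, // 65536   (unchanged)
  Batch           = 0x00020000, // 131072  (new; Compressed's old code)
  Compressed      = 0x00040000, // 262144  (was 131072)
  Singleton       = 0x00080000, // 524288  (was 262144)
  LooseCompressed = 0x00100000, // 1048576 (was 524288)
  NOutOfM         = 0x00200000, // 2097152 (was 1048576)
};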
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 898d2f12779e39..52ca7ba8a1618f 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType;
 
 enum MlirSparseTensorLevelFormat {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
-  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
+  MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000,
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000,
 };
 
 enum MlirSparseTensorLevelPropertyNondefault {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1c81d80ea7ec4e..c8404f10686307 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,12 +154,27 @@ enum class Action : uint32_t {
 enum class LevelFormat : uint64_t {
   Undef = 0x00000000,
   Dense = 0x00010000,
-  Compressed = 0x00020000,
-  Singleton = 0x00040000,
-  LooseCompressed = 0x00080000,
-  NOutOfM = 0x00100000,
+  Batch = 0x00020000,
+  Compressed = 0x00040000,
+  Singleton = 0x00080000,
+  LooseCompressed = 0x00100000,
+  NOutOfM = 0x00200000,
 };
 
+constexpr bool encPowOfTwo(LevelFormat fmt) {
+  auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
+  // http://www.graphics.stanford.edu/~seander/bithacks.html#DetermineIfPowerOf2
+  return (enc & (enc - 1)) == 0;
+}
+
+// All LevelFormat must have only one bit set (power of two).
+static_assert(encPowOfTwo(LevelFormat::Dense) &&
+              encPowOfTwo(LevelFormat::Batch) &&
+              encPowOfTwo(LevelFormat::Compressed) &&
+              encPowOfTwo(LevelFormat::Singleton) &&
+              encPowOfTwo(LevelFormat::LooseCompressed) &&
+              encPowOfTwo(LevelFormat::NOutOfM));
+
 template <LevelFormat... targets>
 constexpr bool isAnyOfFmt(LevelFormat fmt) {
   return (... || (targets == fmt));
@@ -172,6 +187,8 @@ constexpr const char *toFormatString(LevelFormat lvlFmt) {
     return "undef";
   case LevelFormat::Dense:
     return "dense";
+  case LevelFormat::Batch:
+    return "batch";
   case LevelFormat::Compressed:
     return "compressed";
   case LevelFormat::Singleton:
@@ -228,7 +245,7 @@ struct LevelType {
     // If undefined/dense/NOutOfM, then must be unique and ordered.
     // Otherwise, the format must be one of the known ones.
     return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
-                       LevelFormat::NOutOfM>(fmt))
+                       LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
                ? (propertyBits == 0)
                : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
                              LevelFormat::LooseCompressed>(fmt));
@@ -375,6 +392,7 @@ inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
 }
 inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
 inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
+inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
 inline bool isCompressedLT(LevelType lt) {
   return lt.isa<LevelFormat::Compressed>();
 }
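
To make the new validity rule above concrete: batch joins dense and n_out_of_m as a format that is implicitly unique and ordered, so it admits no non-default property bits. A self-contained sketch of the same (format, propertyBits) check as the hunk above, reusing the LevelFormat enum from the earlier sketch:

#include <cstdint>

// (LevelFormat as in the earlier sketch.)
template <LevelFormat... targets>
constexpr bool isAnyOfFmt(LevelFormat fmt) {
  return (... || (targets == fmt));
}

// Sketch of the validity rule from the hunk above.
constexpr bool isValidLT(LevelFormat fmt, uint64_t propertyBits) {
  // Undef/dense/batch/n_out_of_m are implicitly unique and ordered:
  // they must carry no property bits at all.
  if (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense, LevelFormat::Batch,
                 LevelFormat::NOutOfM>(fmt))
    return propertyBits == 0;
  // Otherwise the format must be one of the property-bearing ones.
  return isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
                    LevelFormat::LooseCompressed>(fmt);
}

static_assert(isValidLT(LevelFormat::Batch, /*propertyBits=*/0));
static_assert(!isValidLT(LevelFormat::Batch, /*propertyBits=*/1));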
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index f0b832571e68ec..ca98665256be5a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -141,7 +141,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     The supported level-formats are the following:
 
-    - **dense** : all entries along this level are stored
+    - **dense** : all entries along this level are stored and linearized.
+    - **batch** : all entries along this level are stored but not linearized.
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
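
One way to read "linearized" vs "not linearized" in the doc text above (an interpretation for intuition, not wording from the patch): consecutive dense levels fold into a single flat buffer via index arithmetic, while a batch level keeps an independent sub-storage per batch index. A toy sketch over a 4x8 tensor:

#include <cstddef>
#include <vector>

// dense, dense over 4x8: both levels share one linearized buffer,
// so the level-0 index is folded into a flat offset.
double denseDenseAt(const std::vector<double> &vals, size_t i, size_t j) {
  return vals[i * 8 + j];
}

// batch, dense over 4x8: the batch index selects its own sub-storage;
// it never participates in the linear offset of the inner levels.
double batchDenseAt(const std::vector<std::vector<double>> &batches,
                    size_t i, size_t j) {
  return batches[i][j];
}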
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 455e90baf0a715..92e5efaa810497 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -62,6 +62,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   // Set the base bit for properties.
   if (base.compare("dense") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::Dense);
+  } else if (base.compare("batch") == 0) {
+    properties |= static_cast<uint64_t>(LevelFormat::Batch);
   } else if (base.compare("compressed") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::Compressed);
   } else if (base.compare("structured") == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index af7b85d458774d..fd0ed26fbde072 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -690,6 +690,10 @@ LogicalResult SparseTensorEncodingAttr::verify(
     }
   }
 
+  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
+    return emitError() << "Batch lvlType can only be leading levels.";
+
   // SoA property can only be applied on singleton level.
   auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
     return lt.isa<LevelPropNonDefault::SoA>();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 61a3703b73bf07..011d814cd90094 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1278,6 +1278,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   switch (lt.getLvlFmt()) {
   case LevelFormat::Dense:
     return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
+  case LevelFormat::Batch:
+    llvm_unreachable("not implemented");
   case LevelFormat::Compressed: {
     Value pos = genToPositions(b, l, t, lvl);
     Value crd = genToCoordinates(b, l, t, lvl);
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index a8b9f9048d5912..f241e0e5c2fb56 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
   mlirAffineMapDump(dimToLvl);
   // CHECK: level_type: 65536
-  // CHECK: level_type: 131072
-  // CHECK: level_type: 131072
+  // CHECK: level_type: 262144
+  // CHECK: level_type: 262144
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 9ed3cee2591475..8096c010ac935a 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -54,6 +54,12 @@ func.func private @tensor_dimlevel_size_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
+// expected-error@+1 {{Batch lvlType can only be leading levels}}
+#a = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : batch, d1 : compressed, d2: batch)}>
+func.func private @non_leading_batch(%arg0: tensor<?x?x?i32, #a>) -> ()
+
+// -----
+
 // expected-error@+1 {{use of undeclared identifier}}
 #a = #sparse_tensor.encoding<{map = (d0) -> (d0 : dense, d1 : compressed)}>
 func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 9d5118ceecc587..66e61afd897dd1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,6 +22,17 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
 
+#BCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
+}>
+
+// CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed) }>
+// CHECK-LABEL: func private @sparse_bcsr(
+// CHECK-SAME: tensor<?x?x?xf32, #[[$BCSR]]>)
+func.func private @sparse_bcsr(tensor<?x?x?xf32, #BCSR>)
+
+// -----
+
 #CSR_explicit = #sparse_tensor.encoding<{
   map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index d04fbe8ed5c220..6e8a26762d90fa 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 131072 : i64
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 262144 : i64
 // CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
 // CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
 // CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 2c0603216ef2c2..5666d090c3d5ee 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [131072]
+        # CHECK: lvl_types: [262144]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -71,9 +71,9 @@ def testEncodingAttrStructure():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [65536, 65536, 4406637494272]
+        # CHECK: lvl_types: [65536, 65536, 4406638542848]
         print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
+        # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
         print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
         # CHECK: structured_n: 2
         print(f"structured_n: {casted.structured_n}")
@@ -157,7 +157,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [65536, 131072]
+        # CHECK: lvl_types: [65536, 262144]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")

@PeimingLiu PeimingLiu merged commit 56d5829 into llvm:main Feb 27, 2024
Labels: mlir:sparse (Sparse compiler in MLIR), mlir