Skip to content

[mlir][sparse] Introduce batch level format. #83082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType;

// Level-format encodings exposed through the C API. Each format occupies a
// distinct bit above the 16 property bits, mirroring
// mlir::sparse_tensor::LevelFormat — the two enums must stay in sync.
// NOTE(review): the diff had retained the stale pre-patch values (COMPRESSED
// = 0x20000, ...) alongside the renumbered ones, yielding duplicate
// enumerators; only the merged (post-patch) values are kept here.
enum MlirSparseTensorLevelFormat {
  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
  MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000,
  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000,
  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000,
  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000,
  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000,
};

enum MlirSparseTensorLevelPropertyNondefault {
Expand Down
29 changes: 23 additions & 6 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,26 @@ enum class Action : uint32_t {
// Storage format of a sparse-tensor level. Each format is a distinct bit in
// the upper half of the encoding word (the low 16 bits carry per-level
// properties), so formats can be tested as disjoint masks. Must be kept in
// sync with MlirSparseTensorLevelFormat in the C API header.
// NOTE(review): the diff had retained the stale pre-patch values (Compressed
// = 0x20000, ...) alongside the renumbered ones, yielding duplicate
// enumerators; only the merged (post-patch) values are kept here.
enum class LevelFormat : uint64_t {
  Undef = 0x00000000,           // no format assigned
  Dense = 0x00010000,           // all entries stored, linearized
  Batch = 0x00020000,           // all entries stored, not linearized
  Compressed = 0x00040000,      // only nonzeros stored
  Singleton = 0x00080000,       // compressed variant without sibling coordinates
  LooseCompressed = 0x00100000, // compressed with free space between regions
  NOutOfM = 0x00200000,         // structured n-out-of-m sparsity
};

// Returns true iff the format's encoding has exactly one bit set.
// The zero encoding (LevelFormat::Undef) is explicitly rejected: the original
// `(enc & (enc - 1)) == 0` test alone also accepts 0, which is not a power of
// two. The assertion below tests only nonzero formats, so it is unaffected.
constexpr bool encPowOfTwo(LevelFormat fmt) {
  const auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
  return enc != 0 && (enc & (enc - 1)) == 0;
}

// All (non-undef) LevelFormats must have only one bit set (power of two),
// so each format occupies a distinct mask bit.
static_assert(encPowOfTwo(LevelFormat::Dense) &&
              encPowOfTwo(LevelFormat::Batch) &&
              encPowOfTwo(LevelFormat::Compressed) &&
              encPowOfTwo(LevelFormat::Singleton) &&
              encPowOfTwo(LevelFormat::LooseCompressed) &&
              encPowOfTwo(LevelFormat::NOutOfM));

template <LevelFormat... targets>
constexpr bool isAnyOfFmt(LevelFormat fmt) {
return (... || (targets == fmt));
Expand All @@ -172,6 +186,8 @@ constexpr const char *toFormatString(LevelFormat lvlFmt) {
return "undef";
case LevelFormat::Dense:
return "dense";
case LevelFormat::Batch:
return "batch";
case LevelFormat::Compressed:
return "compressed";
case LevelFormat::Singleton:
Expand Down Expand Up @@ -225,10 +241,10 @@ struct LevelType {
static constexpr bool isValidLvlBits(uint64_t lvlBits) {
auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
const uint64_t propertyBits = lvlBits & 0xffff;
// If undefined/dense/NOutOfM, then must be unique and ordered.
// If undefined/dense/batch/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
LevelFormat::NOutOfM>(fmt))
LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
? (propertyBits == 0)
: (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
LevelFormat::LooseCompressed>(fmt));
Expand Down Expand Up @@ -375,6 +391,7 @@ inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
}
// Convenience predicates: each checks whether the given level type carries
// the corresponding LevelFormat (property bits are ignored by `isa`,
// presumably — TODO confirm against LevelType::isa).
inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
// Batch levels store all entries but, unlike dense, are not linearized.
inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
inline bool isCompressedLT(LevelType lt) {
return lt.isa<LevelFormat::Compressed>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",

The supported level-formats are the following:

- **dense** : all entries along this level are stored
- **dense** : all entries along this level are stored and linearized.
- **batch** : all entries along this level are stored but not linearized.
- **compressed** : only nonzeros along this level are stored.
- **loose_compressed** : as compressed, but allows for free space between regions.
- **singleton** : a variant of the compressed format, where coordinates have no siblings.
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Dense);
} else if (base.compare("batch") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Batch);
} else if (base.compare("compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
} else if (base.compare("structured") == 0) {
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,10 @@ LogicalResult SparseTensorEncodingAttr::verify(
}
}

auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
return emitError() << "Batch lvlType can only be leading levels.";

// SoA property can only be applied on singleton level.
auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
return lt.isa<LevelPropNonDefault::SoA>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
case LevelFormat::Batch:
llvm_unreachable("not implemented");
case LevelFormat::Compressed: {
Value pos = genToPositions(b, l, t, lvl);
Value crd = genToCoordinates(b, l, t, lvl);
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/CAPI/sparse_tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
// CHECK: level_type: 65536
// CHECK: level_type: 131072
// CHECK: level_type: 131072
// CHECK: level_type: 262144
// CHECK: level_type: 262144
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ func.func private @tensor_dimlevel_size_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{Batch lvlType can only be leading levels}}
#a = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : batch, d1 : compressed, d2: batch)}>
func.func private @non_leading_batch(%arg0: tensor<?x?x?i32, #a>) -> ()

// -----

// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : dense, d1 : compressed)}>
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)

// -----

#BCSR = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
}>

// CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed) }>
// CHECK-LABEL: func private @sparse_bcsr(
// CHECK-SAME: tensor<?x?x?xf32, #[[$BCSR]]>)
func.func private @sparse_bcsr(tensor<?x?x?xf32, #BCSR>)

// -----

#CSR_explicit = #sparse_tensor.encoding<{
map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/python/dialects/sparse_tensor/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def testEncodingAttr1D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

# CHECK: lvl_types: [131072]
# CHECK: lvl_types: [262144]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0) -> (d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
Expand Down Expand Up @@ -71,9 +71,9 @@ def testEncodingAttrStructure():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

# CHECK: lvl_types: [65536, 65536, 4406637494272]
# CHECK: lvl_types: [65536, 65536, 4406638542848]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
# CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
# CHECK: structured_n: 2
print(f"structured_n: {casted.structured_n}")
Expand Down Expand Up @@ -157,7 +157,7 @@ def testEncodingAttr2D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

# CHECK: lvl_types: [65536, 131072]
# CHECK: lvl_types: [65536, 262144]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
Expand Down