
[mlir][sparse] Change LevelType enum to 64 bit #80501


Merged · 8 commits · Feb 5, 2024
8 changes: 5 additions & 3 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -25,7 +25,9 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
-enum MlirSparseTensorLevelType {
+typedef uint64_t MlirSparseTensorLevelType;
+
+enum MlirBaseSparseTensorLevelType {
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
@@ -53,15 +55,15 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
MlirContext ctx, intptr_t lvlRank,
-enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
+MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
MlirAffineMap lvlTodim, int posWidth, int crdWidth);

/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
MLIR_CAPI_EXPORTED intptr_t
mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);

/// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
-MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
+MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);

/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
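With the enum replaced by a `uint64_t` typedef, C API callers now pass plain 64-bit integers; the base enumerators remain available for spelling common values. A minimal sketch of a caller after this change — the function, type, and constant names are exactly those declared above, while the context and the two affine maps are placeholders a real caller would already have:

```cpp
#include <stdint.h>
#include "mlir-c/Dialect/SparseTensor.h"

// Sketch only: build a CSR-style encoding (dense outer level, compressed
// inner level) against the updated signature, which now takes an array of
// uint64_t level types instead of enum values.
MlirAttribute makeCsrEncoding(MlirContext ctx, MlirAffineMap dimToLvl,
                              MlirAffineMap lvlToDim) {
  MlirSparseTensorLevelType lvlTypes[2] = {
      MLIR_SPARSE_TENSOR_LEVEL_DENSE,      // 4, 0b00001_00
      MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED, // 8, 0b00010_00
  };
  return mlirSparseTensorEncodingAttrGet(ctx, /*lvlRank=*/2, lvlTypes,
                                         dimToLvl, lvlToDim,
                                         /*posWidth=*/32, /*crdWidth=*/32);
}
```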
38 changes: 19 additions & 19 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -165,7 +165,7 @@ enum class Action : uint32_t {
/// where we need to store an undefined or indeterminate `LevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-enum class LevelType : uint8_t {
+enum class LevelType : uint64_t {
Undef = 0, // 0b00000_00
Dense = 4, // 0b00001_00
Compressed = 8, // 0b00010_00
@@ -184,7 +184,7 @@ enum class LevelType : uint8_t {
};

/// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
+enum class LevelFormat : uint64_t {
Dense = 4, // 0b00001_00
Compressed = 8, // 0b00010_00
Singleton = 16, // 0b00100_00
@@ -193,7 +193,7 @@ };
};

/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
+enum class LevelPropertyNondefault : uint64_t {
Nonunique = 1, // 0b00000_01
Nonordered = 2, // 0b00000_10
};
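As the `0b` comments indicate, the encoding keeps the two low bits for the nondefault properties and the bits above them for a one-hot format, and widening the storage to 64 bits leaves those values unchanged. A small self-contained sketch of that arithmetic — not part of the patch, with values taken from the enumerators above:

```cpp
#include <cstdint>

// Low two bits carry the nondefault properties; the remaining bits carry
// a one-hot format tag, per the 0bFFFFF_PP comments in this header.
constexpr uint64_t kNonunique = 1;  // 0b00000_01
constexpr uint64_t kNonordered = 2; // 0b00000_10

constexpr uint64_t makeLT(uint64_t format, bool ordered, bool unique) {
  return format | (ordered ? 0 : kNonordered) | (unique ? 0 : kNonunique);
}

// Compressed (8) plus the nonunique bit gives compressed_nu (9), matching
// MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU in the C header above.
static_assert(makeLT(8, /*ordered=*/true, /*unique=*/false) == 9);
static_assert(makeLT(8, /*ordered=*/false, /*unique=*/false) == 11);
```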
@@ -237,8 +237,8 @@ constexpr const char *toMLIRString(LevelType lt) {

/// Check that the `LevelType` contains a valid (possibly undefined) value.
constexpr bool isValidLT(LevelType lt) {
-const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
+const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
+const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
// If undefined or dense, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
return (formatBits <= 1 || formatBits == 16)
@@ -251,32 +251,32 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }

/// Check if the `LevelType` is dense (regardless of properties).
constexpr bool isDenseLT(LevelType lt) {
-return (static_cast<uint8_t>(lt) & ~3) ==
-static_cast<uint8_t>(LevelType::Dense);
+return (static_cast<uint64_t>(lt) & ~3) ==
+static_cast<uint64_t>(LevelType::Dense);
}

/// Check if the `LevelType` is compressed (regardless of properties).
constexpr bool isCompressedLT(LevelType lt) {
-return (static_cast<uint8_t>(lt) & ~3) ==
-static_cast<uint8_t>(LevelType::Compressed);
+return (static_cast<uint64_t>(lt) & ~3) ==
+static_cast<uint64_t>(LevelType::Compressed);
}

/// Check if the `LevelType` is singleton (regardless of properties).
constexpr bool isSingletonLT(LevelType lt) {
-return (static_cast<uint8_t>(lt) & ~3) ==
-static_cast<uint8_t>(LevelType::Singleton);
+return (static_cast<uint64_t>(lt) & ~3) ==
+static_cast<uint64_t>(LevelType::Singleton);
}

/// Check if the `LevelType` is loose compressed (regardless of properties).
constexpr bool isLooseCompressedLT(LevelType lt) {
-return (static_cast<uint8_t>(lt) & ~3) ==
-static_cast<uint8_t>(LevelType::LooseCompressed);
+return (static_cast<uint64_t>(lt) & ~3) ==
+static_cast<uint64_t>(LevelType::LooseCompressed);
}

/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
constexpr bool is2OutOf4LT(LevelType lt) {
-return (static_cast<uint8_t>(lt) & ~3) ==
-static_cast<uint8_t>(LevelType::TwoOutOfFour);
+return (static_cast<uint64_t>(lt) & ~3) ==
+static_cast<uint64_t>(LevelType::TwoOutOfFour);
}

/// Check if the `LevelType` needs positions array.
@@ -292,28 +292,28 @@ constexpr bool isWithCrdLT(LevelType lt) {

/// Check if the `LevelType` is ordered (regardless of storage format).
constexpr bool isOrderedLT(LevelType lt) {
-return !(static_cast<uint8_t>(lt) & 2);
+return !(static_cast<uint64_t>(lt) & 2);
}

/// Check if the `LevelType` is unique (regardless of storage format).
constexpr bool isUniqueLT(LevelType lt) {
-return !(static_cast<uint8_t>(lt) & 1);
+return !(static_cast<uint64_t>(lt) & 1);
}

/// Convert a LevelType to its corresponding LevelFormat.
/// Returns std::nullopt when input lt is Undef.
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
if (lt == LevelType::Undef)
return std::nullopt;
-return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
}

/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
bool unique) {
-auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
+auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
(ordered ? 0 : 2) | (unique ? 0 : 1));
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
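A hedged usage sketch of the helpers in this hunk, assuming the usual `mlir::sparse_tensor` namespace from Enums.h — a sketch, not part of the patch:

```cpp
#include <cassert>
#include "mlir/Dialect/SparseTensor/IR/Enums.h"

using namespace mlir::sparse_tensor;

void demoLevelTypeHelpers() {
  // Compressed format, ordered but not unique -> Compressed | nonunique bit.
  std::optional<LevelType> lt = buildLevelType(
      LevelFormat::Compressed, /*ordered=*/true, /*unique=*/false);
  assert(lt && isCompressedLT(*lt) && isOrderedLT(*lt) && !isUniqueLT(*lt));
  // Masking off the property bits recovers the property-free format.
  assert(getLevelFormat(*lt) == LevelFormat::Compressed);
  // Inapplicable properties yield nullopt: dense levels must stay ordered.
  assert(!buildLevelType(LevelFormat::Dense, /*ordered=*/false,
                         /*unique=*/true));
}
```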
2 changes: 1 addition & 1 deletion mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
using namespace mlir::python::adaptors;

static void populateDialectSparseTensorSubmodule(const py::module &m) {
-py::enum_<MlirSparseTensorLevelType>(m, "LevelType", py::module_local())
+py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
.value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
/// Generates a constant of the internal dimension level type encoding.
inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
LevelType lt) {
-return constantI8(builder, loc, static_cast<uint8_t>(lt));
+return constantI64(builder, loc, static_cast<uint64_t>(lt));
}

inline bool isZeroRankedTensorOrScalar(Type type) {
6 changes: 3 additions & 3 deletions mlir/test/CAPI/sparse_tensor.c
@@ -43,11 +43,11 @@ static int testRoundtripEncoding(MlirContext ctx) {
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
-enum MlirSparseTensorLevelType *lvlTypes =
-malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
+MlirSparseTensorLevelType *lvlTypes =
+malloc(sizeof(MlirSparseTensorLevelType) * lvlRank);
for (int l = 0; l < lvlRank; ++l) {
lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);
}
// CHECK: posWidth: 32
int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
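One portability note on the `%lu` conversion above: it matches `uint64_t` only where `unsigned long` is 64 bits (LP64 platforms such as Linux and macOS, but not LLP64 Windows). A variant using the standard `PRIu64` macro would be portable — a sketch, not part of the patch:

```cpp
#include <cinttypes> // PRIu64
#include <cstdint>
#include <cstdio>

void printLevelType(uint64_t lt) {
  // "%" PRIu64 expands to the correct conversion for uint64_t on any
  // data model, unlike "%lu".
  std::fprintf(stderr, "level_type: %" PRIu64 "\n", lt);
}
```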
16 changes: 8 additions & 8 deletions mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> index
// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
// CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
// CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
// CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
12 changes: 6 additions & 6 deletions mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
+// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
+// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
+// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
// CHECK: %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
4 changes: 2 additions & 2 deletions mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

-# CHECK: lvl_types: [<LevelType.compressed: 8>]
+# CHECK: lvl_types: [8]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0) -> (d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

-# CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
+# CHECK: lvl_types: [4, 8]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")