Skip to content

Commit cd481fa

Browse files
[mlir][sparse] Change LevelType enum to 64 bit (#80501)
1. The C++ enum is declared as `enum class LevelType : uint64_t`. 2. The C enum is declared through `typedef uint64_t MlirSparseTensorLevelType`. This is due to a limitation in the Windows build: setting an enum's width to `uint64_t` is not supported in C.
1 parent 0d09120 commit cd481fa

File tree

8 files changed

+45
-43
lines changed

8 files changed

+45
-43
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
2525
/// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
2626
/// If updating, keep them in sync and update the static_assert in the impl
2727
/// file.
28-
enum MlirSparseTensorLevelType {
28+
typedef uint64_t MlirSparseTensorLevelType;
29+
30+
enum MlirBaseSparseTensorLevelType {
2931
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
3032
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
3133
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
@@ -53,15 +55,15 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
5355
/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
5456
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
5557
MlirContext ctx, intptr_t lvlRank,
56-
enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
58+
MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
5759
MlirAffineMap lvlTodim, int posWidth, int crdWidth);
5860

5961
/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
6062
MLIR_CAPI_EXPORTED intptr_t
6163
mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
6264

6365
/// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
64-
MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
66+
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
6567
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
6668

6769
/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ enum class Action : uint32_t {
165165
/// where we need to store an undefined or indeterminate `LevelType`.
166166
/// It should not be used externally, since it does not indicate an
167167
/// actual/representable format.
168-
enum class LevelType : uint8_t {
168+
enum class LevelType : uint64_t {
169169
Undef = 0, // 0b00000_00
170170
Dense = 4, // 0b00001_00
171171
Compressed = 8, // 0b00010_00
@@ -184,7 +184,7 @@ enum class LevelType : uint8_t {
184184
};
185185

186186
/// This enum defines all supported storage format without the level properties.
187-
enum class LevelFormat : uint8_t {
187+
enum class LevelFormat : uint64_t {
188188
Dense = 4, // 0b00001_00
189189
Compressed = 8, // 0b00010_00
190190
Singleton = 16, // 0b00100_00
@@ -193,7 +193,7 @@ enum class LevelFormat : uint8_t {
193193
};
194194

195195
/// This enum defines all the nondefault properties for storage formats.
196-
enum class LevelPropertyNondefault : uint8_t {
196+
enum class LevelPropertyNondefault : uint64_t {
197197
Nonunique = 1, // 0b00000_01
198198
Nonordered = 2, // 0b00000_10
199199
};
@@ -237,8 +237,8 @@ constexpr const char *toMLIRString(LevelType lt) {
237237

238238
/// Check that the `LevelType` contains a valid (possibly undefined) value.
239239
constexpr bool isValidLT(LevelType lt) {
240-
const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
241-
const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
240+
const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
241+
const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
242242
// If undefined or dense, then must be unique and ordered.
243243
// Otherwise, the format must be one of the known ones.
244244
return (formatBits <= 1 || formatBits == 16)
@@ -251,32 +251,32 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
251251

252252
/// Check if the `LevelType` is dense (regardless of properties).
253253
constexpr bool isDenseLT(LevelType lt) {
254-
return (static_cast<uint8_t>(lt) & ~3) ==
255-
static_cast<uint8_t>(LevelType::Dense);
254+
return (static_cast<uint64_t>(lt) & ~3) ==
255+
static_cast<uint64_t>(LevelType::Dense);
256256
}
257257

258258
/// Check if the `LevelType` is compressed (regardless of properties).
259259
constexpr bool isCompressedLT(LevelType lt) {
260-
return (static_cast<uint8_t>(lt) & ~3) ==
261-
static_cast<uint8_t>(LevelType::Compressed);
260+
return (static_cast<uint64_t>(lt) & ~3) ==
261+
static_cast<uint64_t>(LevelType::Compressed);
262262
}
263263

264264
/// Check if the `LevelType` is singleton (regardless of properties).
265265
constexpr bool isSingletonLT(LevelType lt) {
266-
return (static_cast<uint8_t>(lt) & ~3) ==
267-
static_cast<uint8_t>(LevelType::Singleton);
266+
return (static_cast<uint64_t>(lt) & ~3) ==
267+
static_cast<uint64_t>(LevelType::Singleton);
268268
}
269269

270270
/// Check if the `LevelType` is loose compressed (regardless of properties).
271271
constexpr bool isLooseCompressedLT(LevelType lt) {
272-
return (static_cast<uint8_t>(lt) & ~3) ==
273-
static_cast<uint8_t>(LevelType::LooseCompressed);
272+
return (static_cast<uint64_t>(lt) & ~3) ==
273+
static_cast<uint64_t>(LevelType::LooseCompressed);
274274
}
275275

276276
/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
277277
constexpr bool is2OutOf4LT(LevelType lt) {
278-
return (static_cast<uint8_t>(lt) & ~3) ==
279-
static_cast<uint8_t>(LevelType::TwoOutOfFour);
278+
return (static_cast<uint64_t>(lt) & ~3) ==
279+
static_cast<uint64_t>(LevelType::TwoOutOfFour);
280280
}
281281

282282
/// Check if the `LevelType` needs positions array.
@@ -292,28 +292,28 @@ constexpr bool isWithCrdLT(LevelType lt) {
292292

293293
/// Check if the `LevelType` is ordered (regardless of storage format).
294294
constexpr bool isOrderedLT(LevelType lt) {
295-
return !(static_cast<uint8_t>(lt) & 2);
295+
return !(static_cast<uint64_t>(lt) & 2);
296296
}
297297

298298
/// Check if the `LevelType` is unique (regardless of storage format).
299299
constexpr bool isUniqueLT(LevelType lt) {
300-
return !(static_cast<uint8_t>(lt) & 1);
300+
return !(static_cast<uint64_t>(lt) & 1);
301301
}
302302

303303
/// Convert a LevelType to its corresponding LevelFormat.
304304
/// Returns std::nullopt when input lt is Undef.
305305
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
306306
if (lt == LevelType::Undef)
307307
return std::nullopt;
308-
return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
308+
return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
309309
}
310310

311311
/// Convert a LevelFormat to its corresponding LevelType with the given
312312
/// properties. Returns std::nullopt when the properties are not applicable
313313
/// for the input level format.
314314
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
315315
bool unique) {
316-
auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
316+
auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
317317
(ordered ? 0 : 2) | (unique ? 0 : 1));
318318
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
319319
}

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using namespace mlir;
2323
using namespace mlir::python::adaptors;
2424

2525
static void populateDialectSparseTensorSubmodule(const py::module &m) {
26-
py::enum_<MlirSparseTensorLevelType>(m, "LevelType", py::module_local())
26+
py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
2727
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
2828
.value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
2929
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
423423
/// Generates a constant of the internal dimension level type encoding.
424424
inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
425425
LevelType lt) {
426-
return constantI8(builder, loc, static_cast<uint8_t>(lt));
426+
return constantI64(builder, loc, static_cast<uint64_t>(lt));
427427
}
428428

429429
inline bool isZeroRankedTensorOrScalar(Type type) {

mlir/test/CAPI/sparse_tensor.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ static int testRoundtripEncoding(MlirContext ctx) {
4343
MlirAffineMap lvlToDim =
4444
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
4545
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
46-
enum MlirSparseTensorLevelType *lvlTypes =
47-
malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
46+
MlirSparseTensorLevelType *lvlTypes =
47+
malloc(sizeof(MlirSparseTensorLevelType) * lvlRank);
4848
for (int l = 0; l < lvlRank; ++l) {
4949
lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
50-
fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
50+
fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);
5151
}
5252
// CHECK: posWidth: 32
5353
int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
7878
// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
7979
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
8080
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
81-
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
82-
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
81+
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
82+
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
8383
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
8484
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
8585
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
9696
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
9797
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
9898
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
99-
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
100-
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
99+
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
100+
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
101101
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
102102
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
103103
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
114114
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
115115
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
116116
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
117-
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
118-
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
117+
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
118+
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
119119
// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
120120
// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
121121
// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
136136
// CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
137137
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
138138
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
139-
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
139+
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
140140
// CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
141141
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
142-
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
142+
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
143143
// CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
144144
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
145145
// CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>

mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
1515
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
1616
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
17-
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i8
18-
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
19-
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
20-
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
21-
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
17+
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i64
18+
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
19+
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
20+
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
21+
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
2222
// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
2323
// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
2424
// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
2828
// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
2929
// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
3030
// CHECK: %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
31-
// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
31+
// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
3232
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
3333
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
3434
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>

mlir/test/python/dialects/sparse_tensor/dialect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
2828
# CHECK: equal: True
2929
print(f"equal: {casted == parsed}")
3030

31-
# CHECK: lvl_types: [<LevelType.compressed: 8>]
31+
# CHECK: lvl_types: [8]
3232
print(f"lvl_types: {casted.lvl_types}")
3333
# CHECK: dim_to_lvl: (d0) -> (d0)
3434
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
7070
# CHECK: equal: True
7171
print(f"equal: {casted == parsed}")
7272

73-
# CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
73+
# CHECK: lvl_types: [4, 8]
7474
print(f"lvl_types: {casted.lvl_types}")
7575
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
7676
print(f"dim_to_lvl: {casted.dim_to_lvl}")

0 commit comments

Comments
 (0)