Skip to content

Commit c8fe1e1

Browse files
[mlir][sparse] Enable explicit and implicit value in sparse encoding
1. Explicit value means the non-zero value in a sparse tensor. If explicitVal is set, then all the non-zero values in the tensor have the same explicit value. It has the default value Attribute(). 2. Implicit value means the "zero" value in a sparse tensor. For now, we only support 0 as the implicit value but it could be extended in the future. It has the default value Attribute(). Example: #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }> Note: this PR tests that implicitVal could be set to other values as well. The following PR will add verifier and reject any value that's not zero for implicitVal.
1 parent 76600ae commit c8fe1e1

File tree

9 files changed

+285
-37
lines changed

9 files changed

+285
-37
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
5353
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
5454
MlirContext ctx, intptr_t lvlRank,
5555
MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
56-
MlirAffineMap lvlTodim, int posWidth, int crdWidth);
56+
MlirAffineMap lvlTodim, int posWidth, int crdWidth,
57+
MlirAttribute explicitVal, MlirAttribute implicitVal);
5758

5859
/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
5960
MLIR_CAPI_EXPORTED intptr_t
@@ -85,6 +86,14 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
8586
MLIR_CAPI_EXPORTED int
8687
mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr);
8788

89+
/// Returns the explicit value of the `sparse_tensor.encoding` attribute.
90+
MLIR_CAPI_EXPORTED MlirAttribute
91+
mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr);
92+
93+
/// Returns the implicit value of the `sparse_tensor.encoding` attribute.
94+
MLIR_CAPI_EXPORTED MlirAttribute
95+
mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr);
96+
8897
MLIR_CAPI_EXPORTED unsigned
8998
mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType);
9099

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,16 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
183183
coordinate over all levels). The choices are `8`, `16`, `32`,
184184
`64`, or, the default, `0` to indicate a native bitwidth.
185185

186+
- The required explicit value for the sparse tensor. If explicitVal is set,
187+
then all the non-zero values in the tensor have the same explicit value.
188+
The default value Attribute() indicates that it is not set.
189+
190+
- The required implicit value for the sparse tensor. If implicitVal is set,
191+
then the "zero" value in the tensor is equal to the implicit value.
192+
For now, we only support `0` as the implicit value but it could be
193+
extended in the future. The default value Attribute() indicates that
194+
the implicit value is `0` (same type as the tensor element type).
195+
186196
Examples:
187197

188198
```mlir
@@ -226,6 +236,15 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
226236
}>
227237
... tensor<8x8xf64, #DCSC> ...
228238

239+
// Doubly compressed sparse column storage with specific
240+
// explicit and implicit values.
241+
#DCSC = #sparse_tensor.encoding<{
242+
map = (i, j) -> (j : compressed, i : compressed),
243+
explicitVal = 1 : i64,
244+
implicitVal = 0 : i64
245+
}>
246+
... tensor<8x8xi64, #DCSC> ...
247+
229248
// Block sparse row storage (2x3 blocks).
230249
#BSR = #sparse_tensor.encoding<{
231250
map = ( i, j ) ->
@@ -307,6 +326,12 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
307326
// The required bitwidth for coordinate storage.
308327
"unsigned":$crdWidth,
309328

329+
// The required explicit value.
330+
"::mlir::Attribute":$explicitVal,
331+
332+
// The required implicit value.
333+
"::mlir::Attribute":$implicitVal,
334+
310335
// A slice attribute for each dimension of the tensor type.
311336
ArrayRefParameter<
312337
"::mlir::sparse_tensor::SparseTensorDimSliceAttr",
@@ -319,14 +344,17 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
319344
CArg<"AffineMap", "{}">:$dimToLvl,
320345
CArg<"AffineMap", "{}">:$lvlToDim,
321346
CArg<"unsigned", "0">:$posWidth,
322-
CArg<"unsigned", "0">:$crdWidth), [{
347+
CArg<"unsigned", "0">:$crdWidth,
348+
CArg<"::mlir::Attribute", "{}">:$explicitVal,
349+
CArg<"::mlir::Attribute", "{}">:$implicitVal), [{
323350
if (!dimToLvl) {
324351
dimToLvl = ::mlir::AffineMap::getMultiDimIdentityMap(lvlTypes.size(), $_ctxt);
325352
}
326353
if (!lvlToDim) {
327354
lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
328355
}
329356
return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
357+
explicitVal, implicitVal,
330358
ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
331359
}]>
332360
];
@@ -353,6 +381,22 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
353381
/// reset to the default, and all other fields inherited from `this`.
354382
SparseTensorEncodingAttr withoutBitWidths() const;
355383

384+
/// Constructs a new encoding with the given explicit value
385+
/// and all other fields inherited from `this`.
386+
SparseTensorEncodingAttr withExplicitVal(Attribute explicitVal) const;
387+
388+
/// Constructs a new encoding with the explicit value
389+
/// reset to the default, and all other fields inherited from `this`.
390+
SparseTensorEncodingAttr withoutExplicitVal() const;
391+
392+
/// Constructs a new encoding with the given implicit value
393+
/// and all other fields inherited from `this`.
394+
SparseTensorEncodingAttr withImplicitVal(Attribute implicitVal) const;
395+
396+
/// Constructs a new encoding with the implicit value
397+
/// reset to the default, and all other fields inherited from `this`.
398+
SparseTensorEncodingAttr withoutImplicitVal() const;
399+
356400
/// Constructs a new encoding with the given dimSlices, and all
357401
/// other fields inherited from `this`.
358402
SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ class SparseTensorType {
115115
return withEncoding(enc.withoutBitWidths());
116116
}
117117

118+
SparseTensorType withExplicitVal(Attribute explicitVal) const {
119+
return withEncoding(enc.withExplicitVal(explicitVal));
120+
}
121+
122+
SparseTensorType withoutExplicitVal() const {
123+
return withEncoding(enc.withoutExplicitVal());
124+
}
125+
126+
SparseTensorType withImplicitVal(Attribute implicitVal) const {
127+
return withEncoding(enc.withImplicitVal(implicitVal));
128+
}
129+
130+
SparseTensorType withoutImplicitVal() const {
131+
return withEncoding(enc.withoutImplicitVal());
132+
}
133+
118134
SparseTensorType
119135
withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
120136
return withEncoding(enc.withDimSlices(dimSlices));
@@ -327,6 +343,12 @@ class SparseTensorType {
327343
/// Returns the position-overhead bitwidth, defaulting to zero.
328344
unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
329345

346+
/// Returns the explicit value, defaulting to empty Attribute.
347+
Attribute getExplicitVal() const { return enc.getExplicitVal(); }
348+
349+
/// Returns the implicit value, defaulting to empty Attribute.
350+
Attribute getImplicitVal() const { return enc.getImplicitVal(); }
351+
330352
/// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
331353
Type getCrdType() const { return enc.getCrdElemType(); }
332354

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,19 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
4242
[](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
4343
std::optional<MlirAffineMap> dimToLvl,
4444
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
45-
MlirContext context) {
45+
std::optional<MlirAttribute> explicitVal,
46+
std::optional<MlirAttribute> implicitVal, MlirContext context) {
4647
return cls(mlirSparseTensorEncodingAttrGet(
4748
context, lvlTypes.size(), lvlTypes.data(),
4849
dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
4950
lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
50-
crdWidth));
51+
crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
52+
implicitVal ? *implicitVal : MlirAttribute{nullptr}));
5153
},
5254
py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
5355
py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
54-
py::arg("context") = py::none(),
56+
py::arg("explicit_val") = py::none(),
57+
py::arg("implicit_val") = py::none(), py::arg("context") = py::none(),
5558
"Gets a sparse_tensor.encoding from parameters.")
5659
.def_classmethod(
5760
"build_level_type",
@@ -97,6 +100,24 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
97100
mlirSparseTensorEncodingAttrGetPosWidth)
98101
.def_property_readonly("crd_width",
99102
mlirSparseTensorEncodingAttrGetCrdWidth)
103+
.def_property_readonly(
104+
"explicit_val",
105+
[](MlirAttribute self) -> std::optional<MlirAttribute> {
106+
MlirAttribute ret =
107+
mlirSparseTensorEncodingAttrGetExplicitVal(self);
108+
if (mlirAttributeIsNull(ret))
109+
return {};
110+
return ret;
111+
})
112+
.def_property_readonly(
113+
"implicit_val",
114+
[](MlirAttribute self) -> std::optional<MlirAttribute> {
115+
MlirAttribute ret =
116+
mlirSparseTensorEncodingAttrGetImplicitVal(self);
117+
if (mlirAttributeIsNull(ret))
118+
return {};
119+
return ret;
120+
})
100121
.def_property_readonly(
101122
"structured_n",
102123
[](MlirAttribute self) -> unsigned {

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,20 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
4444
return isa<SparseTensorEncodingAttr>(unwrap(attr));
4545
}
4646

47-
MlirAttribute
48-
mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
49-
MlirSparseTensorLevelType const *lvlTypes,
50-
MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
51-
int posWidth, int crdWidth) {
47+
MlirAttribute mlirSparseTensorEncodingAttrGet(
48+
MlirContext ctx, intptr_t lvlRank,
49+
MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
50+
MlirAffineMap lvlToDim, int posWidth, int crdWidth,
51+
MlirAttribute explicitVal, MlirAttribute implicitVal) {
5252
SmallVector<LevelType> cppLvlTypes;
53+
5354
cppLvlTypes.reserve(lvlRank);
5455
for (intptr_t l = 0; l < lvlRank; ++l)
5556
cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l]));
56-
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
57-
unwrap(dimToLvl), unwrap(lvlToDim),
58-
posWidth, crdWidth));
57+
58+
return wrap(SparseTensorEncodingAttr::get(
59+
unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth,
60+
crdWidth, unwrap(explicitVal), unwrap(implicitVal)));
5961
}
6062

6163
MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
@@ -91,6 +93,14 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
9193
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
9294
}
9395

96+
MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) {
97+
return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal());
98+
}
99+
100+
MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) {
101+
return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal());
102+
}
103+
94104
MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
95105
enum MlirSparseTensorLevelFormat lvlFmt,
96106
const enum MlirSparseTensorLevelPropertyNondefault *properties,

0 commit comments

Comments
 (0)