Skip to content

Commit 836411b

Browse files
authored
[mlir][sparse] add lvlToDim field to sparse tensor encoding (#67194)
Note the new surface syntax allows for defining a dimToLvl and lvlToDim map at once (where usually the latter can be inferred from the former, but not always). This revision adds storage for the latter, together with some intial boilerplate. The actual support (inference, validation, printing, etc.) is still TBD of course.
1 parent 8ea7430 commit 836411b

File tree

8 files changed

+54
-22
lines changed

8 files changed

+54
-22
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ MLIR_CAPI_EXPORTED bool
5151
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
5252

5353
/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
54+
/// TODO: add a version that supplied lvlToDim when it cannot be inferred
5455
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
5556
MlirContext ctx, intptr_t lvlRank,
5657
enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
@@ -69,6 +70,11 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
6970
MLIR_CAPI_EXPORTED MlirAffineMap
7071
mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr);
7172

73+
/// Returns the level-to-dimension mapping of the `sparse_tensor.encoding`
74+
/// attribute.
75+
MLIR_CAPI_EXPORTED MlirAffineMap
76+
mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr);
77+
7278
/// Returns the position bitwidth of the `sparse_tensor.encoding` attribute.
7379
MLIR_CAPI_EXPORTED int
7480
mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
237237
//
238238
let parameters = (
239239
ins
240-
// A level-type for each level of the sparse storage.
240+
// A level-type for each level of the sparse storage
241+
// (consists of a level-format combined with level-properties).
241242
ArrayRefParameter<
242243
"::mlir::sparse_tensor::DimLevelType",
243244
"level-types"
@@ -246,6 +247,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
246247
// A mapping from dimension-coordinates to level-coordinates.
247248
"AffineMap":$dimToLvl,
248249

250+
// A mapping from level-coordinates to dimension-coordinates.
251+
"AffineMap":$lvlToDim,
252+
249253
// The required bitwidth for position storage.
250254
"unsigned":$posWidth,
251255

@@ -262,9 +266,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
262266
let builders = [
263267
AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$lvlTypes,
264268
"AffineMap":$dimToLvl,
269+
"AffineMap":$lvlToDim,
265270
"unsigned":$posWidth,
266271
"unsigned":$crdWidth), [{
267-
return $_get($_ctxt, lvlTypes, dimToLvl, posWidth, crdWidth,
272+
return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
268273
ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
269274
}]>
270275
];

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ namespace sparse_tensor {
4545
///
4646
class SparseTensorType {
4747
public:
48-
// We memoize `lvlRank` and `dimToLvl` to avoid repeating the
49-
// conditionals throughout the rest of the class.
48+
// We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating
49+
// the conditionals throughout the rest of the class.
5050
SparseTensorType(RankedTensorType rtp)
5151
: rtp(rtp), enc(getSparseTensorEncoding(rtp)),
5252
lvlRank(enc ? enc.getLvlRank() : getDimRank()),
53-
dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()) {
53+
dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()),
54+
lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) {
5455
assert(rtp && "got null RankedTensorType");
5556
assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
5657
}
@@ -201,6 +202,9 @@ class SparseTensorType {
201202
/// see `hasSameDimToLvl` instead.
202203
AffineMap getDimToLvl() const { return dimToLvl; }
203204

205+
/// Returns the lvlToDiml mapping (or the null-map for the identity).
206+
AffineMap getLvlToDim() const { return lvlToDim; }
207+
204208
/// Returns the dimToLvl mapping, where the identity map is expanded out
205209
/// into a full `AffineMap`. This method is provided as a convenience,
206210
/// but for most purposes other methods (`isIdentity`, `getDimToLvl`,
@@ -306,6 +310,7 @@ class SparseTensorType {
306310
// Memoized to avoid frequent redundant conditionals.
307311
const Level lvlRank;
308312
const AffineMap dimToLvl;
313+
const AffineMap lvlToDim;
309314
};
310315

311316
/// Convenience method to abbreviate wrapping `getRankedTensorType`.

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
4343
[](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
4444
std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
4545
MlirContext context) {
46+
// TODO: provide dimToLvl
4647
return cls(mlirSparseTensorEncodingAttrGet(
4748
context, lvlTypes.size(), lvlTypes.data(),
4849
dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,20 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
5454
cppLvlTypes.reserve(lvlRank);
5555
for (intptr_t l = 0; l < lvlRank; ++l)
5656
cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
57-
return wrap(SparseTensorEncodingAttr::get(
58-
unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), posWidth, crdWidth));
57+
mlir::AffineMap lvlToDim; // TODO: provide in API
58+
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
59+
unwrap(dimToLvl), lvlToDim,
60+
posWidth, crdWidth));
5961
}
6062

6163
MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
6264
return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl());
6365
}
6466

67+
MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) {
68+
return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim());
69+
}
70+
6571
intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
6672
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
6773
}

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,10 @@ Type SparseTensorEncodingAttr::getCrdType() const {
293293
SparseTensorEncodingAttr
294294
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
295295
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
296+
// TODO: infer lvlToDim
296297
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
297-
getPosWidth(), getCrdWidth());
298+
/*lvlToDim*/ AffineMap(), getPosWidth(),
299+
getCrdWidth());
298300
}
299301

300302
SparseTensorEncodingAttr
@@ -311,7 +313,8 @@ SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
311313
unsigned crdWidth) const {
312314
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
313315
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
314-
getDimToLvl(), posWidth, crdWidth);
316+
getDimToLvl(), getLvlToDim(), posWidth,
317+
crdWidth);
315318
}
316319

317320
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
@@ -321,8 +324,8 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
321324
SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
322325
ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
323326
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
324-
getDimToLvl(), getPosWidth(),
325-
getCrdWidth(), dimSlices);
327+
getDimToLvl(), getLvlToDim(),
328+
getPosWidth(), getCrdWidth(), dimSlices);
326329
}
327330

328331
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
@@ -576,8 +579,10 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
576579
#undef RETURN_ON_FAIL
577580

578581
// Construct struct-like storage for attribute.
582+
AffineMap lvlToDim; // TODO: infer
579583
return parser.getChecked<SparseTensorEncodingAttr>(
580-
parser.getContext(), lvlTypes, dimToLvl, posWidth, crdWidth, dimSlices);
584+
parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
585+
dimSlices);
581586
}
582587

583588
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -608,10 +613,12 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
608613
printer << " }>";
609614
}
610615

611-
LogicalResult SparseTensorEncodingAttr::verify(
612-
function_ref<InFlightDiagnostic()> emitError,
613-
ArrayRef<DimLevelType> lvlTypes, AffineMap dimToLvl, unsigned posWidth,
614-
unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
616+
LogicalResult
617+
SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
618+
ArrayRef<DimLevelType> lvlTypes,
619+
AffineMap dimToLvl, AffineMap lvlToDim,
620+
unsigned posWidth, unsigned crdWidth,
621+
ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
615622
if (!acceptBitWidth(posWidth))
616623
return emitError() << "unexpected position bitwidth: " << posWidth;
617624
if (!acceptBitWidth(crdWidth))
@@ -631,7 +638,7 @@ LogicalResult SparseTensorEncodingAttr::verify(
631638
return emitError()
632639
<< "level-rank mismatch between dimToLvl and lvlTypes: "
633640
<< dimToLvl.getNumResults() << " != " << lvlRank;
634-
// TODO: The following is attempting to match the old error-conditions
641+
// TODO: The following is attempting to match the old error-conditions
635642
// from prior to merging dimOrdering and higherOrdering into dimToLvl.
636643
// That is, we currently require `dimToLvl` to be either a permutation
637644
// (as when higherOrdering is the identity) or expansive (as per the
@@ -674,7 +681,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
674681
// Check structural integrity. In particular, this ensures that the
675682
// level-rank is coherent across all the fields.
676683
RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
677-
getPosWidth(), getCrdWidth(), getDimSlices()))
684+
getLvlToDim(), getPosWidth(), getCrdWidth(),
685+
getDimSlices()))
678686
// Check integrity with tensor type specifics. In particular, we
679687
// need only check that the dimension-rank of the tensor agrees with
680688
// the dimension-rank of the encoding.
@@ -763,8 +771,9 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
763771
// default value.
764772
unsigned posWidth = src.getPosWidth();
765773
unsigned crdWidth = src.getCrdWidth();
774+
AffineMap invPerm; // TODO
766775
auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
767-
posWidth, crdWidth);
776+
invPerm, posWidth, crdWidth);
768777
return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
769778
}
770779

@@ -836,6 +845,7 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
836845
return SparseTensorEncodingAttr::get(
837846
enc.getContext(), dlts,
838847
AffineMap(), // dimToLvl (irrelevant to storage specifier)
848+
AffineMap(), // lvlToDim (irrelevant to storage specifier)
839849
// Always use `index` for memSize and lvlSize instead of reusing
840850
// `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
841851
// value for different bitwidth, it also avoids casting between index and

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
986986
const auto dstEnc = SparseTensorEncodingAttr::get(
987987
op->getContext(),
988988
SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
989-
srcEnc.getPosWidth(), srcEnc.getCrdWidth());
989+
AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
990990
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
991991
Value iter = NewCallParams(rewriter, loc)
992992
.genBuffers(dstTp.withEncoding(dstEnc), dimSizes)

mlir/test/CAPI/sparse_tensor.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,12 @@ static int testRoundtripEncoding(MlirContext ctx) {
5454
// CHECK: crdWidth: 64
5555
int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
5656
fprintf(stderr, "crdWidth: %d\n", crdWidth);
57-
57+
// TODO: lvlToDim
5858
MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
5959
ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
6060
mlirAttributeDump(newAttr); // For debugging filecheck output.
6161
// CHECK: equal: 1
6262
fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
63-
6463
free(lvlTypes);
6564
return 0;
6665
}

0 commit comments

Comments
 (0)