Skip to content

[mlir][sparse] add lvlToDim field to sparse tensor encoding #67194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ MLIR_CAPI_EXPORTED bool
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);

/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
/// TODO: add a version that supplied lvlToDim when it cannot be inferred
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
MlirContext ctx, intptr_t lvlRank,
enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
Expand All @@ -69,6 +70,11 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
MLIR_CAPI_EXPORTED MlirAffineMap
mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr);

/// Returns the level-to-dimension mapping of the `sparse_tensor.encoding`
/// attribute.
MLIR_CAPI_EXPORTED MlirAffineMap
mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr);

/// Returns the position bitwidth of the `sparse_tensor.encoding` attribute.
MLIR_CAPI_EXPORTED int
mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
//
let parameters = (
ins
// A level-type for each level of the sparse storage.
// A level-type for each level of the sparse storage
// (consists of a level-format combined with level-properties).
ArrayRefParameter<
"::mlir::sparse_tensor::DimLevelType",
"level-types"
Expand All @@ -246,6 +247,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
// A mapping from dimension-coordinates to level-coordinates.
"AffineMap":$dimToLvl,

// A mapping from level-coordinates to dimension-coordinates.
"AffineMap":$lvlToDim,

// The required bitwidth for position storage.
"unsigned":$posWidth,

Expand All @@ -262,9 +266,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
let builders = [
AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$lvlTypes,
"AffineMap":$dimToLvl,
"AffineMap":$lvlToDim,
"unsigned":$posWidth,
"unsigned":$crdWidth), [{
return $_get($_ctxt, lvlTypes, dimToLvl, posWidth, crdWidth,
return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
}]>
];
Expand Down
11 changes: 8 additions & 3 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ namespace sparse_tensor {
///
class SparseTensorType {
public:
// We memoize `lvlRank` and `dimToLvl` to avoid repeating the
// conditionals throughout the rest of the class.
// We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating
// the conditionals throughout the rest of the class.
SparseTensorType(RankedTensorType rtp)
: rtp(rtp), enc(getSparseTensorEncoding(rtp)),
lvlRank(enc ? enc.getLvlRank() : getDimRank()),
dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()) {
dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()),
lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) {
assert(rtp && "got null RankedTensorType");
assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
}
Expand Down Expand Up @@ -201,6 +202,9 @@ class SparseTensorType {
/// see `hasSameDimToLvl` instead.
AffineMap getDimToLvl() const { return dimToLvl; }

/// Returns the lvlToDiml mapping (or the null-map for the identity).
AffineMap getLvlToDim() const { return lvlToDim; }

/// Returns the dimToLvl mapping, where the identity map is expanded out
/// into a full `AffineMap`. This method is provided as a convenience,
/// but for most purposes other methods (`isIdentity`, `getDimToLvl`,
Expand Down Expand Up @@ -306,6 +310,7 @@ class SparseTensorType {
// Memoized to avoid frequent redundant conditionals.
const Level lvlRank;
const AffineMap dimToLvl;
const AffineMap lvlToDim;
};

/// Convenience method to abbreviate wrapping `getRankedTensorType`.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/DialectSparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
[](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
MlirContext context) {
// TODO: provide dimToLvl
return cls(mlirSparseTensorEncodingAttrGet(
context, lvlTypes.size(), lvlTypes.data(),
dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/CAPI/Dialect/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,20 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
cppLvlTypes.reserve(lvlRank);
for (intptr_t l = 0; l < lvlRank; ++l)
cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
return wrap(SparseTensorEncodingAttr::get(
unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), posWidth, crdWidth));
mlir::AffineMap lvlToDim; // TODO: provide in API
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
unwrap(dimToLvl), lvlToDim,
posWidth, crdWidth));
}

MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl());
}

MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) {
return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim());
}

intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
}
Expand Down
34 changes: 22 additions & 12 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,10 @@ Type SparseTensorEncodingAttr::getCrdType() const {
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
// TODO: infer lvlToDim
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
getPosWidth(), getCrdWidth());
/*lvlToDim*/ AffineMap(), getPosWidth(),
getCrdWidth());
}

SparseTensorEncodingAttr
Expand All @@ -311,7 +313,8 @@ SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
unsigned crdWidth) const {
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
getDimToLvl(), posWidth, crdWidth);
getDimToLvl(), getLvlToDim(), posWidth,
crdWidth);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
Expand All @@ -321,8 +324,8 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
getDimToLvl(), getPosWidth(),
getCrdWidth(), dimSlices);
getDimToLvl(), getLvlToDim(),
getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
Expand Down Expand Up @@ -576,8 +579,10 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
#undef RETURN_ON_FAIL

// Construct struct-like storage for attribute.
AffineMap lvlToDim; // TODO: infer
return parser.getChecked<SparseTensorEncodingAttr>(
parser.getContext(), lvlTypes, dimToLvl, posWidth, crdWidth, dimSlices);
parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
dimSlices);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
Expand Down Expand Up @@ -608,10 +613,12 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
printer << " }>";
}

LogicalResult SparseTensorEncodingAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
ArrayRef<DimLevelType> lvlTypes, AffineMap dimToLvl, unsigned posWidth,
unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
LogicalResult
SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<DimLevelType> lvlTypes,
AffineMap dimToLvl, AffineMap lvlToDim,
unsigned posWidth, unsigned crdWidth,
ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
if (!acceptBitWidth(posWidth))
return emitError() << "unexpected position bitwidth: " << posWidth;
if (!acceptBitWidth(crdWidth))
Expand All @@ -631,7 +638,7 @@ LogicalResult SparseTensorEncodingAttr::verify(
return emitError()
<< "level-rank mismatch between dimToLvl and lvlTypes: "
<< dimToLvl.getNumResults() << " != " << lvlRank;
// TODO: The following is attempting to match the old error-conditions
// TODO: The following is attempting to match the old error-conditions
// from prior to merging dimOrdering and higherOrdering into dimToLvl.
// That is, we currently require `dimToLvl` to be either a permutation
// (as when higherOrdering is the identity) or expansive (as per the
Expand Down Expand Up @@ -674,7 +681,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
// Check structural integrity. In particular, this ensures that the
// level-rank is coherent across all the fields.
RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
getPosWidth(), getCrdWidth(), getDimSlices()))
getLvlToDim(), getPosWidth(), getCrdWidth(),
getDimSlices()))
// Check integrity with tensor type specifics. In particular, we
// need only check that the dimension-rank of the tensor agrees with
// the dimension-rank of the encoding.
Expand Down Expand Up @@ -763,8 +771,9 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
// default value.
unsigned posWidth = src.getPosWidth();
unsigned crdWidth = src.getCrdWidth();
AffineMap invPerm; // TODO
auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
posWidth, crdWidth);
invPerm, posWidth, crdWidth);
return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
}

Expand Down Expand Up @@ -836,6 +845,7 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
return SparseTensorEncodingAttr::get(
enc.getContext(), dlts,
AffineMap(), // dimToLvl (irrelevant to storage specifier)
AffineMap(), // lvlToDim (irrelevant to storage specifier)
// Always use `index` for memSize and lvlSize instead of reusing
// `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
// value for different bitwidth, it also avoids casting between index and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
const auto dstEnc = SparseTensorEncodingAttr::get(
op->getContext(),
SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
srcEnc.getPosWidth(), srcEnc.getCrdWidth());
AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
Value iter = NewCallParams(rewriter, loc)
.genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
Expand Down
3 changes: 1 addition & 2 deletions mlir/test/CAPI/sparse_tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,12 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: crdWidth: 64
int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
fprintf(stderr, "crdWidth: %d\n", crdWidth);

// TODO: lvlToDim
MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
mlirAttributeDump(newAttr); // For debugging filecheck output.
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));

free(lvlTypes);
return 0;
}
Expand Down