
[mlir][sparse] introduce sparse_tensor.reinterpret_map operation. #70378

Merged: 3 commits, Oct 26, 2023
@@ -430,8 +430,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;

//
// Helper function to build IR related to the encoding.
// Helper function to translate between level/dimension space.
//
SmallVector<int64_t> tranlateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;

//
49 changes: 49 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -208,6 +208,55 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
let hasVerifier = 1;
}

def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemoryEffect]>,
Arguments<(ins AnySparseTensor:$source)>,
Results<(outs AnySparseTensor:$dest)> {
let summary = "Reinterprets the dimension/level maps of the source tensor";
let description = [{
Reinterprets the dimension-to-level and level-to-dimension maps specified in
`source` according to the type of `dest`.
`reinterpret_map` is a no-op and is introduced merely to resolve type conflicts.
It does not modify the source tensor; the `source` and `dest` tensors are
considered to be aliases.
[Review comment from a Contributor]: ah, we probably need to follow up with the right bufferization interfaces later to make that clear

`source` and `dest` tensors are "reinterpretable" if and only if they have
exactly the same low-level storage.
That is, both `source` and `dest` have the same number of levels and the same
level types, and their shapes are consistent before and after `reinterpret_map`.

Example:
```mlir
#CSC = #sparse_tensor.encoding<{
map = (d0, d1) -> (d1: dense, d0: compressed)
}>
#CSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0: dense, d1: compressed)
}>
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<3x4xi32, #CSC> to tensor<4x3xi32, #CSR>

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>
#DSDD = #sparse_tensor.encoding<{
map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
}>
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR> to tensor<3x4x2x3xi32, #DSDD>
```
}];

let builders = [
OpBuilder<(ins "SparseTensorEncodingAttr":$dstEnc, "Value":$source)>
];

let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
let hasVerifier = 1;
}

def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
@@ -245,6 +245,12 @@ class SparseTensorType {
/// Returns the dimension-shape.
ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }

/// Returns the Level-shape.
SmallVector<DynSize> getLvlShape() const {
return getEncoding().tranlateShape(getDimShape(),
CrdTransDirectionKind::dim2lvl);
}

/// Safely looks up the requested dimension-DynSize. If you intend
/// to check the result with `ShapedType::isDynamic`, then see the
/// `getStaticDimSize` method instead.
@@ -281,6 +287,7 @@ class SparseTensorType {
/// `ShapedType::Trait<T>::getNumDynamicDims`.
int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }

ArrayRef<DimLevelType> getLvlTypes() const { return enc.getLvlTypes(); }
DimLevelType getLvlType(Level l) const {
// This OOB check is for dense-tensors, since this class knows
// their lvlRank (whereas STEA::getLvlType will/can only check
107 changes: 107 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -415,6 +415,55 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
return getStaticDimSliceStride(toOrigDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
CrdTransDirectionKind dir) const {
if (isIdentity())
return SmallVector<int64_t>(srcShape);

SmallVector<int64_t> ret;
unsigned rank =
dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
ret.reserve(rank);

if (isPermutation()) {
for (unsigned r = 0; r < rank; r++) {
unsigned trans = dir == CrdTransDirectionKind::dim2lvl
? toOrigDim(*this, r)
: toStoredDim(*this, r);
ret.push_back(srcShape[trans]);
}
return ret;
}

// Handle non-permutation maps.
AffineMap transMap =
dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

SmallVector<AffineExpr> dimRep;
dimRep.reserve(srcShape.size());
for (int64_t sz : srcShape) {
if (!ShapedType::isDynamic(sz)) {
// Push back the max coordinate for the given dimension/level size.
dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
} else {
// A dynamic size: use an AffineDimExpr to symbolize the value.
dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
}
}

for (AffineExpr exp : transMap.getResults()) {
// Do constant propagation on the affine map.
AffineExpr evalExp =
simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
ret.push_back(c.getValue() + 1);
else
ret.push_back(ShapedType::kDynamic);
}
return ret;
}
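
// Illustrative note (editorial, not part of the patch): for the #BSR map
// (i, j) -> (i floordiv 2, j floordiv 3, i mod 2, j mod 3) used in the tests,
// a static dimension shape 6x12 has maximum coordinates (5, 11); evaluating
// the level expressions on those coordinates yields (2, 3, 1, 2), and adding 1
// recovers the level shape 3x4x2x3 expected by the tests below.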

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
ValueRange crds,
@@ -1292,6 +1341,64 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
return {};
}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
SparseTensorEncodingAttr dstEnc, Value source) {
auto srcStt = getSparseTensorType(source);
SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
SmallVector<int64_t> dstDimShape =
dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
auto dstTp =
RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
return build(odsBuilder, odsState, dstTp, source);
}
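
// Usage sketch (editorial assumption, not part of the patch): a rewriter can
// let this builder infer the destination type from the destination encoding,
// e.g. turning a tensor<6x12xi32, #BSR> value into tensor<3x4x2x3xi32, #DSDD>,
// where `dsddEnc` and `bsrVal` are hypothetical names for the #DSDD encoding
// attribute and the #BSR-typed source value:
//   Value remapped = rewriter.create<ReinterpretMapOp>(loc, dsddEnc, bsrVal);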

LogicalResult ReinterpretMapOp::verify() {
auto srcStt = getSparseTensorType(getSource());
auto dstStt = getSparseTensorType(getDest());
ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();

if (srcLvlTps.size() != dstLvlTps.size())
return emitError("Level rank mismatch between source/dest tensors");

for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
if (srcLvlTp != dstLvlTp)
return emitError("Level type mismatch between source/dest tensors");

if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
return emitError("Crd/Pos width mismatch between source/dest tensors");
}

if (srcStt.getElementType() != dstStt.getElementType())
return emitError("Element type mismatch between source/dest tensors");

SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
if (srcLvlSz != dstLvlSz) {
// Should we allow one side to be dynamic size, e.g., <?x?> should be
// compatible to <3x4>? For now, we require all the level sizes to be
// *exactly* matched for simplicity.
return emitError("Level size mismatch between source/dest tensors");
}
}

return success();
}

OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
if (getSource().getType() == getDest().getType())
return getSource();

if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
// A -> B, B -> A ==> A
if (def.getSource().getType() == getDest().getType())
return def.getSource();
}
return {};
}

LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
15 changes: 15 additions & 0 deletions mlir/test/Dialect/SparseTensor/fold.mlir
@@ -111,3 +111,18 @@ func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
return %l0 : index
}

#DSDD = #sparse_tensor.encoding<{
map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
}>


// CHECK-LABEL: func.func @sparse_reinterpret_map(
// CHECK-NOT: sparse_tensor.reinterpret_map
func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<6x12xi32, #BSR> {
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
to tensor<3x4x2x3xi32, #DSDD>
%t2 = sparse_tensor.reinterpret_map %t1 : tensor<3x4x2x3xi32, #DSDD>
to tensor<6x12xi32, #BSR>
return %t2 : tensor<6x12xi32, #BSR>
}
63 changes: 63 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -913,3 +913,66 @@ func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
return %l0 : index
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

#DSDC = #sparse_tensor.encoding<{
map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: compressed)
}>

func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDC> {
// expected-error@+1 {{Level type mismatch between source/dest tensors}}
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
to tensor<3x4x2x3xf32, #DSDC>
return %t1 : tensor<3x4x2x3xf32, #DSDC>
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

#DSDD = #sparse_tensor.encoding<{
map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
}>

func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDD> {
// expected-error@+1 {{Element type mismatch between source/dest tensors}}
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
to tensor<3x4x2x3xf32, #DSDD>
return %t1 : tensor<3x4x2x3xf32, #DSDD>
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

#DSDD = #sparse_tensor.encoding<{
map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
}>

func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x4xi32, #DSDD> {
// expected-error@+1 {{Level size mismatch between source/dest tensors}}
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
to tensor<3x4x2x4xi32, #DSDD>
return %t1 : tensor<3x4x2x4xi32, #DSDD>
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -690,3 +690,23 @@ func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
return %l0 : index
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

#DSDD = #sparse_tensor.encoding<{
map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
}>

func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xi32, #DSDD> {
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
to tensor<3x4x2x3xi32, #DSDD>
return %t1 : tensor<3x4x2x3xi32, #DSDD>
}