
Commit d808d92

[mlir][sparse] introduce sparse_tensor.reinterpret_map operation. (llvm#70378)
1 parent 3911810 commit d808d92

File tree: 7 files changed, +263 -1 lines

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 2 additions & 1 deletion
```diff
@@ -430,8 +430,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;

   //
-  // Helper function to build IR related to the encoding.
+  // Helper function to translate between level/dimension space.
   //
+  SmallVector<int64_t> tranlateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
   ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;

   //
```

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 49 additions & 0 deletions
````diff
@@ -208,6 +208,55 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   let hasVerifier = 1;
 }

+def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemoryEffect]>,
+    Arguments<(ins AnySparseTensor:$source)>,
+    Results<(outs AnySparseTensor:$dest)> {
+  let summary = "Reinterprets the dimension/level maps of the source tensor";
+  let description = [{
+    Reinterprets the dimension-to-level and level-to-dimension maps specified in
+    `source` according to the type of `dest`.
+    `reinterpret_map` is a no-op and is introduced merely to resolve type conflicts.
+    It does not modify the source tensor, and the source/dest tensors
+    are considered to be aliases.
+
+    The `source` and `dest` tensors are "reinterpretable" if and only if they have
+    exactly the same storage at a low level.
+    That is, both `source` and `dest` have the same number of levels and level types,
+    and their shapes are consistent before and after `reinterpret_map`.
+
+    Example:
+    ```mlir
+    #CSC = #sparse_tensor.encoding<{
+      map = (d0, d1) -> (d1: dense, d0: compressed)
+    }>
+    #CSR = #sparse_tensor.encoding<{
+      map = (d0, d1) -> (d0: dense, d1: compressed)
+    }>
+    %t1 = sparse_tensor.reinterpret_map %t0 : tensor<3x4xi32, #CSC> to tensor<4x3xi32, #CSR>
+
+    #BSR = #sparse_tensor.encoding<{
+      map = ( i, j ) -> ( i floordiv 2 : dense,
+                          j floordiv 3 : compressed,
+                          i mod 2      : dense,
+                          j mod 3      : dense
+      )
+    }>
+    #DSDD = #sparse_tensor.encoding<{
+      map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+    }>
+    %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR> to tensor<3x4x2x3xi32, #DSDD>
+    ```
+  }];
+
+  let builders = [
+    OpBuilder<(ins "SparseTensorEncodingAttr":$dstEnc, "Value":$source)>
+  ];
+
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
````
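As a worked check of the `#BSR` example above: the dim-to-lvl map sends the `6x12` dim-shape to the level shape `(6 floordiv 2, 12 floordiv 3, 2, 3) = (3, 4, 2, 3)`, which is exactly the dimension shape of the `#DSDD` dest type. A minimal sketch of creating the op through the new type-inferring builder; the surrounding `builder`, `loc`, `dsddEnc`, and `source` values are hypothetical context, not part of the commit:

```cpp
// Hypothetical context: `source` has type tensor<6x12xi32, #BSR> and
// `dsddEnc` is the #DSDD SparseTensorEncodingAttr. The builder added in this
// commit infers the dest type by translating the source's level shape
// (3, 4, 2, 3) back through #DSDD's lvl-to-dim map, yielding
// tensor<3x4x2x3xi32, #DSDD>.
Value dest = builder.create<sparse_tensor::ReinterpretMapOp>(loc, dsddEnc, source);
```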

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 7 additions & 0 deletions
```diff
@@ -245,6 +245,12 @@ class SparseTensorType {
   /// Returns the dimension-shape.
   ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }

+  /// Returns the Level-shape.
+  SmallVector<DynSize> getLvlShape() const {
+    return getEncoding().tranlateShape(getDimShape(),
+                                       CrdTransDirectionKind::dim2lvl);
+  }
+
   /// Safely looks up the requested dimension-DynSize. If you intend
   /// to check the result with `ShapedType::isDynamic`, then see the
   /// `getStaticDimSize` method instead.
@@ -281,6 +287,7 @@ class SparseTensorType {
   /// `ShapedType::Trait<T>::getNumDynamicDims`.
   int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }

+  ArrayRef<DimLevelType> getLvlTypes() const { return enc.getLvlTypes(); }
   DimLevelType getLvlType(Level l) const {
     // This OOB check is for dense-tensors, since this class knows
     // their lvlRank (whereas STEA::getLvlType will/can only check
```
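A hedged usage sketch of the new accessors; the value `val` and its `#BSR` type are assumptions carried over from the examples above:

```cpp
// For a value of type tensor<6x12xi32, #BSR>, the dim-shape {6, 12}
// translates through the dim-to-lvl map to the lvl-shape {3, 4, 2, 3}.
SparseTensorType stt = getSparseTensorType(val);
ArrayRef<DynSize> dimShape = stt.getDimShape();    // {6, 12}
SmallVector<DynSize> lvlShape = stt.getLvlShape(); // {3, 4, 2, 3}
```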

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 107 additions & 0 deletions
```diff
@@ -412,6 +412,55 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }

+SmallVector<int64_t>
+SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
+                                        CrdTransDirectionKind dir) const {
+  if (isIdentity())
+    return SmallVector<int64_t>(srcShape);
+
+  SmallVector<int64_t> ret;
+  unsigned rank =
+      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
+  ret.reserve(rank);
+
+  if (isPermutation()) {
+    for (unsigned r = 0; r < rank; r++) {
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
+                           ? toOrigDim(*this, r)
+                           : toStoredDim(*this, r);
+      ret.push_back(srcShape[trans]);
+    }
+    return ret;
+  }
+
+  // Handle non-permutation maps.
+  AffineMap transMap =
+      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
+
+  SmallVector<AffineExpr> dimRep;
+  dimRep.reserve(srcShape.size());
+  for (int64_t sz : srcShape) {
+    if (!ShapedType::isDynamic(sz)) {
+      // Push back the max coordinate for the given dimension/level size.
+      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
+    } else {
+      // A dynamic size; use an AffineDimExpr to symbolize the value.
+      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
+    }
+  }
+
+  for (AffineExpr exp : transMap.getResults()) {
+    // Do constant propagation on the affine map.
+    AffineExpr evalExp =
+        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
+    if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+      ret.push_back(c.getValue() + 1);
+    else
+      ret.push_back(ShapedType::kDynamic);
+  }
+  return ret;
+}
+
 ValueRange
 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                         ValueRange crds,
```
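The non-permutation branch computes each static level size by substituting the maximum coordinate (`size - 1`) for every static dimension, constant-folding the map, and adding 1 back; dynamic inputs become `AffineDimExpr`s, so any level that depends on them folds to `kDynamic`. A standalone sketch of the same trick for fully static shapes (a hypothetical helper, not part of the commit, using only core `AffineExpr` APIs):

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;

// Evaluate static level sizes for a dim-to-lvl map, e.g. the #BSR map
// (i floordiv 2, j floordiv 3, i mod 2, j mod 3) on a 6x12 dim-shape.
SmallVector<int64_t> evalStaticLvlSizes(AffineMap dimToLvl,
                                        ArrayRef<int64_t> dimShape) {
  MLIRContext *ctx = dimToLvl.getContext();
  // Substitute the max coordinate (size - 1) for each dimension.
  SmallVector<AffineExpr> maxCrds;
  for (int64_t sz : dimShape)
    maxCrds.push_back(getAffineConstantExpr(sz - 1, ctx));
  SmallVector<int64_t> lvlSizes;
  for (AffineExpr e : dimToLvl.getResults()) {
    AffineExpr folded =
        simplifyAffineExpr(e.replaceDims(maxCrds), dimShape.size(), 0);
    // E.g. (i floordiv 2) at the max coordinate i = 5 folds to 2, so the
    // first #BSR level has size 2 + 1 = 3; overall: {3, 4, 2, 3}.
    lvlSizes.push_back(folded.cast<AffineConstantExpr>().getValue() + 1);
  }
  return lvlSizes;
}
```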
```diff
@@ -1286,6 +1335,64 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
   return {};
 }

+void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                             SparseTensorEncodingAttr dstEnc, Value source) {
+  auto srcStt = getSparseTensorType(source);
+  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<int64_t> dstDimShape =
+      dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
+  auto dstTp =
+      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
+  return build(odsBuilder, odsState, dstTp, source);
+}
+
+LogicalResult ReinterpretMapOp::verify() {
+  auto srcStt = getSparseTensorType(getSource());
+  auto dstStt = getSparseTensorType(getDest());
+  ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
+  ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
+
+  if (srcLvlTps.size() != dstLvlTps.size())
+    return emitError("Level rank mismatch between source/dest tensors");
+
+  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
+    if (srcLvlTp != dstLvlTp)
+      return emitError("Level type mismatch between source/dest tensors");
+
+  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
+      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
+    return emitError("Crd/Pos width mismatch between source/dest tensors");
+  }
+
+  if (srcStt.getElementType() != dstStt.getElementType())
+    return emitError("Element type mismatch between source/dest tensors");
+
+  SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
+    if (srcLvlSz != dstLvlSz) {
+      // Should we allow one side to be dynamic size, e.g., <?x?> should be
+      // compatible to <3x4>? For now, we require all the level sizes to be
+      // *exactly* matched for simplicity.
+      return emitError("Level size mismatch between source/dest tensors");
+    }
+  }
+
+  return success();
+}
+
+OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
+  if (getSource().getType() == getDest().getType())
+    return getSource();
+
+  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
+    // A -> B, B -> A ==> A
+    if (def.getSource().getType() == getDest().getType())
+      return def.getSource();
+  }
+  return {};
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
```
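The in-code comment in the verifier leaves dynamic/static unification open; under the current exact-match rule, mixing a dynamic and a static level size is rejected. A hedged sketch of IR that would trip that check (hypothetical, reusing the `#CSR` encoding from the op documentation):

```mlir
// Level sizes (?, ?) vs. (3, 4) do not match exactly, so verification
// fails with "Level size mismatch between source/dest tensors".
%t1 = sparse_tensor.reinterpret_map %t0 : tensor<?x?xi32, #CSR> to tensor<3x4xi32, #CSR>
```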

mlir/test/Dialect/SparseTensor/fold.mlir

Lines changed: 15 additions & 0 deletions
```diff
@@ -111,3 +111,18 @@ func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
   return %l0 : index
 }
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+
+// CHECK-LABEL: func.func @sparse_reinterpret_map(
+//   CHECK-NOT:   sparse_tensor.reinterpret_map
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<6x12xi32, #BSR> {
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xi32, #DSDD>
+  %t2 = sparse_tensor.reinterpret_map %t1 : tensor<3x4x2x3xi32, #DSDD>
+                                         to tensor<6x12xi32, #BSR>
+  return %t2 : tensor<6x12xi32, #BSR>
+}
```

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 63 additions & 0 deletions
```diff
@@ -964,3 +964,66 @@ func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
   return %l0 : index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDC = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: compressed)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDC> {
+  // expected-error@+1 {{Level type mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xf32, #DSDC>
+  return %t1 : tensor<3x4x2x3xf32, #DSDC>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDD> {
+  // expected-error@+1 {{Element type mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xf32, #DSDD>
+  return %t1 : tensor<3x4x2x3xf32, #DSDD>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x4xi32, #DSDD> {
+  // expected-error@+1 {{Level size mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x4xi32, #DSDD>
+  return %t1 : tensor<3x4x2x4xi32, #DSDD>
+}
```

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 20 additions & 0 deletions
```diff
@@ -690,3 +690,23 @@ func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
   return %l0 : index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xi32, #DSDD> {
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xi32, #DSDD>
+  return %t1 : tensor<3x4x2x3xi32, #DSDD>
+}
```
