Skip to content

Commit f0f5fdf

Browse files
authored
[mlir][sparse] introduce sparse_tensor.lvl operation. (#69978)
1 parent e6005d5 commit f0f5fdf

File tree

6 files changed

+205
-1
lines changed

6 files changed

+205
-1
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def SparseTensor_Dialect : Dialect {
9090

9191
let useDefaultAttributePrinterParser = 1;
9292
let useDefaultTypePrinterParser = 1;
93+
let hasConstantMaterializer = 1;
9394
}
9495

9596
#endif // SPARSETENSOR_BASE

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,9 +521,67 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
521521
}
522522

523523
//===----------------------------------------------------------------------===//
524-
// Sparse Tensor Coordinate Translation Operation.
524+
// Sparse Tensor Coordinate Operations.
525525
//===----------------------------------------------------------------------===//
526526

527+
def SparseTensor_LvlOp : SparseTensor_Op<"lvl", [ConditionallySpeculatable, NoMemoryEffect]>,
    Arguments<(ins AnySparseTensor:$source, Index:$index)>,
    Results<(outs Index:$result)> {
  let summary = "level index operation";
  let description = [{
    The `sparse_tensor.lvl` operation behaves similarly to the `tensor.dim`
    operation. It takes a sparse tensor and a level operand of type `index`
    and returns the size of the requested level of the given sparse tensor.
    If the sparse tensor has an identity dimension to level mapping, it
    returns the same result as `tensor.dim`.
    If the level index is out of bounds, the behavior is undefined.

    Example:

    ```mlir
    #BSR = #sparse_tensor.encoding<{
      map = ( i, j ) ->
        ( i floordiv 2 : dense,
          j floordiv 3 : compressed,
          i mod 2      : dense,
          j mod 3      : dense
        )
    }>

    // Always returns 2 (4 floordiv 2), can be constant folded:
    %c0 = arith.constant 0 : index
    %x = sparse_tensor.lvl %A, %c0 : tensor<4x?xf32, #BSR>

    // Returns the dynamic level size of %A computed by j floordiv 3:
    %c1 = arith.constant 1 : index
    %y = sparse_tensor.lvl %A, %c1 : tensor<4x?xf32, #BSR>

    // Always returns 3 (since j mod 3 < 3), can be constant folded:
    %c3 = arith.constant 3 : index
    %z = sparse_tensor.lvl %A, %c3 : tensor<4x?xf32, #BSR>
    ```
  }];

  let assemblyFormat = [{
      attr-dict $source `,` $index `:` type($source)
  }];

  let builders = [
    OpBuilder<(ins "Value":$source, "int64_t":$index)>
  ];

  let extraClassDeclaration = [{
    /// Helper function to get the index as a simple integer if it is constant.
    std::optional<uint64_t> getConstantLvlIndex();

    /// Interface method for ConditionallySpeculatable.
    Speculation::Speculatability getSpeculatability();
  }];

  let hasVerifier = 1;
  let hasFolder = 1;
}
584+
527585
def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
528586
Arguments<(ins Variadic<Index>:$in_crds,
529587
SparseTensorCrdTransDirectionAttr:$direction,

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,84 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
12081208
return success();
12091209
}
12101210

1211+
LogicalResult LvlOp::verify() {
  // Only a constant level operand can be bounds-checked statically; a dynamic
  // operand is verified at best-effort (out-of-bounds is UB, per the op doc).
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    // BUGFIX: the diagnostic must also fail verification. The original code
    // emitted the error but fell through to `return success()`, so an
    // out-of-bounds constant level index was reported yet still accepted.
    if (lvl.value() >= stt.getLvlRank())
      return emitError("Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}
1219+
1220+
std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1221+
return getConstantIntValue(getIndex());
1222+
}
1223+
1224+
Speculation::Speculatability LvlOp::getSpeculatability() {
  // A dynamic level operand can never be proven in-bounds.
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  // BUGFIX: compare against the *level* rank of the sparse encoding, not the
  // rank of the ranked tensor type (the *dimension* rank). For a
  // non-permutation encoding such as BSR the level rank (e.g. 4) exceeds the
  // dimension rank (e.g. 2), so the original
  //   assert(constantIndex < cast<RankedTensorType>(...).getRank())
  // fired on perfectly valid in-bounds level indices. Also, an out-of-bounds
  // constant is still valid IR (UB at runtime, mirroring tensor.dim), so we
  // conservatively refuse to speculate instead of asserting.
  auto stt = getSparseTensorType(getSource());
  if (constantIndex.value() >= stt.getLvlRank())
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
1233+
1234+
OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  // Folding requires a constant level operand.
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by tensor.dim operation. Out of bound
    // indices produce undefined behavior but are still valid IR. Don't choke on
    // them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  // TODO: we can remove this after SparseTensorEncoding always returns non-null
  // dimToLvl map.
  ArrayRef<DynSize> shape = stt.getDimShape();
  if (stt.isPermutation()) {
    // Permutation: the level size is the (possibly permuted) dimension size.
    Dimension dim = toOrigDim(stt, lvl);
    if (!ShapedType::isDynamic(shape[dim]))
      return getIndexAttr(shape[dim]);
    return {};
  }

  // Non-permutation dim2lvl/lvl2dim maps.
  AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
  if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (lvlExpr.getKind() == AffineExprKind::Mod) {
      // j % block_sz, the level size equals the block size.
      int64_t lvlSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
      return getIndexAttr(lvlSz);
    }
    if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
      // j / block_sz, the level size equals dim[j] / block_sz.
      Dimension dim = binExpr.getLHS().cast<AffineDimExpr>().getPosition();
      int64_t blockSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
      if (ShapedType::isDynamic(shape[dim]))
        return {};
      return getIndexAttr(shape[dim] / blockSz);
    }
  }

  // Directly mapped dimension: fold to the dimension size when it is static.
  // BUGFIX: the dynamic-ness test must be applied to the dimension *size*
  // (shape[dim]), not to the dimension position `dim` itself; the original
  // code folded a dynamically sized dimension to the kDynamic sentinel value.
  auto dim = lvlExpr.cast<AffineDimExpr>().getPosition();
  if (!ShapedType::isDynamic(shape[dim]))
    return getIndexAttr(shape[dim]);

  return {};
}
1288+
12111289
LogicalResult ToPositionsOp::verify() {
12121290
auto e = getSparseTensorEncoding(getTensor().getType());
12131291
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1639,6 +1717,16 @@ LogicalResult YieldOp::verify() {
16391717
// TensorDialect Methods.
16401718
//===----------------------------------------------------------------------===//
16411719

1720+
/// Materialize a single constant operation from a given attribute value with
1721+
/// the desired resultant type.
1722+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Delegates to the arith dialect, which knows
/// how to build index/integer constants.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  auto constOp = arith::ConstantOp::materialize(builder, value, type, loc);
  if (!constOp)
    return nullptr;
  return constOp;
}
1729+
16421730
void SparseTensorDialect::initialize() {
16431731
addAttributes<
16441732
#define GET_ATTRDEF_LIST

mlir/test/Dialect/SparseTensor/fold.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
9393
%d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
9494
return %d0, %d1 : index, index
9595
}
96+
97+
// Level 0 of #BSR is `i floordiv 2`; dim 0 is static (10), so the op
// constant-folds to 10 floordiv 2 = 5.
// CHECK-LABEL: func.func @sparse_lvl_0(
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: return %[[C5]] : index
func.func @sparse_lvl_0(%t : tensor<10x?xi32, #BSR>) -> index {
  %c0 = arith.constant 0 : index
  %sz = sparse_tensor.lvl %t, %c0 : tensor<10x?xi32, #BSR>
  return %sz : index
}
105+
106+
// Level 3 of #BSR is `j mod 3`, whose size is the block size 3 regardless
// of the (dynamic) dimension sizes, so the op constant-folds to 3.
// CHECK-LABEL: func.func @sparse_lvl_3(
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: return %[[C3]] : index
func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
  %c3 = arith.constant 3 : index
  %sz = sparse_tensor.lvl %t, %c3 : tensor<?x?xi32, #BSR>
  return %sz : index
}

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,3 +895,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (in
895895
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
896896
return %l0, %l1, %l2, %l3 : index, index, index, index
897897
}
898+
899+
// -----

#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) ->
    ( i floordiv 2 : dense,
      j floordiv 3 : compressed,
      i mod 2      : dense,
      j mod 3      : dense
    )
}>

// A constant level index of 5 exceeds #BSR's level rank (4) and must be
// rejected by the verifier.
func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
  %c5 = arith.constant 5 : index
  // expected-error@+1 {{Level index exceeds the rank of the input sparse tensor}}
  %sz = sparse_tensor.lvl %t, %c5 : tensor<?x?xi32, #BSR>
  return %sz : index
}

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,24 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, in
669669
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
670670
return %l0, %l1, %l2, %l3 : index, index, index, index
671671
}
672+
673+
// -----

#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) ->
    ( i floordiv 2 : dense,
      j floordiv 3 : compressed,
      i mod 2      : dense,
      j mod 3      : dense
    )
}>

// Round-trip: with a dynamic level operand the op cannot fold and must
// print/parse back unchanged.
// CHECK-LABEL: func.func @sparse_lvl(
// CHECK-SAME: %[[VAL_0:.*]]: index,
// CHECK-SAME: %[[VAL_1:.*]]: tensor
// CHECK: %[[VAL_2:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_0]]
// CHECK: return %[[VAL_2]]
func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
  %sz = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
  return %sz : index
}

0 commit comments

Comments
 (0)