Skip to content

Commit ff21a90

Browse files
authored
[mlir][sparse] introduce sparse_tensor.crd_translate operation (llvm#69630)
1 parent f681852 commit ff21a90

File tree

6 files changed

+177
-0
lines changed

6 files changed

+177
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,4 +541,26 @@ def SparseTensorSortKindAttr
541541
"SparseTensorSortAlgorithm"> {
542542
}
543543

544+
545+
//===----------------------------------------------------------------------===//
546+
// Sparse Tensor Coordinate Translation Direction Attribute.
547+
//===----------------------------------------------------------------------===//
548+
549+
// The C++ enum for sparse tensor coordinate translation direction enum.
550+
def SparseTensorCrdTransDirectionEnum
551+
: I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
552+
I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
553+
I32EnumAttrCase<"lvl2dim", 1, "lvl_to_dim">,
554+
]> {
555+
let genSpecializedAttr = 0;
556+
let cppNamespace = SparseTensor_Dialect.cppNamespace;
557+
}
558+
559+
// The C++ enum for sparse tensor coordinate translation direction attribute.
560+
def SparseTensorCrdTransDirectionAttr
561+
: EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
562+
"CrdTransDirection"> {
563+
}
564+
565+
544566
#endif // SPARSETENSOR_ATTRDEFS

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,32 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
520520
let hasVerifier = 1;
521521
}
522522

523+
//===----------------------------------------------------------------------===//
524+
// Sparse Tensor Coordinate Translation Operation.
525+
//===----------------------------------------------------------------------===//
526+
527+
def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
528+
Arguments<(ins Variadic<Index>:$in_crds,
529+
SparseTensorCrdTransDirectionAttr:$direction,
530+
SparseTensorEncodingAttr:$encoder)>,
531+
Results<(outs Variadic<Index>:$out_crds)> {
532+
string summary = "Performs coordinate translation between level and dimension coordinate space.";
533+
string description = [{
534+
Performs coordinate translation between level and dimension coordinate space according
535+
to the affine maps defined by $encoder.
536+
537+
Example:
538+
539+
```mlir
540+
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #BSR
541+
: index, index, index, index
542+
```
543+
}];
544+
let assemblyFormat = "$direction `[` $in_crds `]` `as` $encoder attr-dict `:` type($out_crds)";
545+
let hasVerifier = 1;
546+
let hasFolder = 1;
547+
}
548+
523549
//===----------------------------------------------------------------------===//
524550
// Sparse Tensor Management Operations. These operations are "impure" in the
525551
// sense that some behavior is defined by side-effects. These operations provide

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,61 @@ bool ConvertOp::needsExtraSort() {
11401140
return true;
11411141
}
11421142

1143+
LogicalResult CrdTranslateOp::verify() {
1144+
uint64_t inRank = getEncoder().getLvlRank();
1145+
uint64_t outRank = getEncoder().getDimRank();
1146+
1147+
if (getDirection() == CrdTransDirectionKind::dim2lvl)
1148+
std::swap(inRank, outRank);
1149+
1150+
if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1151+
return emitError("Coordinate rank mismatch with encoding");
1152+
1153+
return success();
1154+
}
1155+
1156+
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1157+
SmallVectorImpl<OpFoldResult> &results) {
1158+
if (getEncoder().isPermutation()) {
1159+
AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1160+
? getEncoder().getDimToLvl()
1161+
: getEncoder().getLvlToDim();
1162+
for (AffineExpr exp : perm.getResults())
1163+
results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
1164+
return success();
1165+
}
1166+
1167+
// Fuse dim2lvl/lvl2dim pairs.
1168+
auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1169+
bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1170+
return v.getDefiningOp() == def;
1171+
});
1172+
if (!sameDef)
1173+
return failure();
1174+
1175+
bool oppositeDir = def.getDirection() != getDirection();
1176+
bool sameOracle =
1177+
def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1178+
bool sameCount = def.getNumResults() == getInCrds().size();
1179+
if (!oppositeDir || !sameOracle || !sameCount)
1180+
return failure();
1181+
1182+
// The definition produces the coordinates in the same order as the input
1183+
// coordinates.
1184+
bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1185+
[](auto valuePair) {
1186+
auto [lhs, rhs] = valuePair;
1187+
return lhs == rhs;
1188+
});
1189+
1190+
if (!sameOrder)
1191+
return failure();
1192+
// l1 = dim2lvl (lvl2dim l0)
1193+
// ==> l0
1194+
results.append(def.getInCrds().begin(), def.getInCrds().end());
1195+
return success();
1196+
}
1197+
11431198
LogicalResult ToPositionsOp::verify() {
11441199
auto e = getSparseTensorEncoding(getTensor().getType());
11451200
if (failed(lvlIsInBounds(getLevel(), getTensor())))

mlir/test/Dialect/SparseTensor/fold.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,21 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32,
7575
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
7676
return %ret : tensor<?x?xf32, #COO>
7777
}
78+
79+
80+
#BSR = #sparse_tensor.encoding<{
81+
map = ( i, j ) ->
82+
( i floordiv 2 : dense,
83+
j floordiv 3 : compressed,
84+
i mod 2 : dense,
85+
j mod 3 : dense
86+
)
87+
}>
88+
89+
// CHECK-LABEL: func @sparse_crd_translate(
90+
// CHECK-NOT: sparse_tensor.crd_translate
91+
func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
92+
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
93+
%d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
94+
return %d0, %d1 : index, index
95+
}

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,3 +861,37 @@ func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -
861861
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
862862
return %ret : tensor<?x?xf64, #OrderedCOO>
863863
}
864+
865+
// -----
866+
867+
#BSR = #sparse_tensor.encoding<{
868+
map = ( i, j ) ->
869+
( i floordiv 2 : dense,
870+
j floordiv 3 : compressed,
871+
i mod 2 : dense,
872+
j mod 3 : dense
873+
)
874+
}>
875+
876+
func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index) {
877+
// expected-error@+1 {{Coordinate rank mismatch with encoding}}
878+
%l0, %l1, %l2 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index
879+
return %l0, %l1, %l2 : index, index, index
880+
}
881+
882+
// -----
883+
884+
#BSR = #sparse_tensor.encoding<{
885+
map = ( i, j ) ->
886+
( i floordiv 2 : dense,
887+
j floordiv 3 : compressed,
888+
i mod 2 : dense,
889+
j mod 3 : dense
890+
)
891+
}>
892+
893+
func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
894+
// expected-error@+1 {{Coordinate rank mismatch with encoding}}
895+
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
896+
return %l0, %l1, %l2, %l3 : index, index, index, index
897+
}

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,3 +647,25 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<
647647
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
648648
return %ret : tensor<?x?xf32, #OrderedCOO>
649649
}
650+
651+
652+
// -----
653+
654+
#BSR = #sparse_tensor.encoding<{
655+
map = ( i, j ) ->
656+
( i floordiv 2 : dense,
657+
j floordiv 3 : compressed,
658+
i mod 2 : dense,
659+
j mod 3 : dense
660+
)
661+
}>
662+
663+
// CHECK-LABEL: func.func @sparse_crd_translate(
664+
// CHECK-SAME: %[[VAL_0:.*]]: index,
665+
// CHECK-SAME: %[[VAL_1:.*]]: index)
666+
// CHECK: %[[VAL_2:.*]]:4 = sparse_tensor.crd_translate dim_to_lvl{{\[}}%[[VAL_0]], %[[VAL_1]]]
667+
// CHECK: return %[[VAL_2]]#0, %[[VAL_2]]#1, %[[VAL_2]]#2, %[[VAL_2]]#3
668+
func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index, index) {
669+
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
670+
return %l0, %l1, %l2, %l3 : index, index, index, index
671+
}

0 commit comments

Comments
 (0)