
[mlir][sparse] introduce sparse_tensor.crd_translate operation #69630


Merged · 4 commits · Oct 19, 2023
Changes from 2 commits
@@ -541,4 +541,26 @@ def SparseTensorSortKindAttr
"SparseTensorSortAlgorithm"> {
}


//===----------------------------------------------------------------------===//
// Sparse Tensor Coordinate Translation Direction Attribute.
//===----------------------------------------------------------------------===//

// The C++ enum for the sparse tensor coordinate translation direction.
def SparseTensorCrdTransDirectionEnum
: I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
I32EnumAttrCase<"lvl2dim", 1, "lvl_to_dim">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = SparseTensor_Dialect.cppNamespace;
}

// Define the enum sparse tensor coordinate translation direction attribute.
def SparseTensorCrdTransDirectionAttr
: EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
"CrdTransDirection"> {
}


#endif // SPARSETENSOR_ATTRDEFS
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -520,6 +520,25 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Coordinate Translation Operation.
//===----------------------------------------------------------------------===//

def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
Arguments<(ins Variadic<Index>:$in_crds,
SparseTensorCrdTransDirectionAttr:$direction,
SparseTensorEncodingAttr:$oracle)>,
Results<(outs Variadic<Index>:$out_crds)> {
let summary = "Performs coordinate translation between level and dimension coordinate space";
let description = [{
Performs coordinate translation between the level and dimension coordinate
spaces according to the affine maps (dimToLvl/lvlToDim) defined by the
sparse tensor encoding given as the `oracle` attribute.
}];
let assemblyFormat = "$direction `[` $in_crds `]` `as` $oracle attr-dict `:` type($out_crds)";
let hasVerifier = 1;
let hasFolder = 1;
}
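
As an illustrative sketch (not part of the diff, assuming the op as defined in this patch), the new op reads as follows when applied to the block-sparse #BSR encoding used by the tests further down; `dim_to_lvl` maps a dimension coordinate pair (i, j) to the four level coordinates of the 2x3 blocked layout:

```mlir
// Illustrative only: the same #BSR encoding as in the tests of this patch.
#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) ->
  ( i floordiv 2 : dense,
    j floordiv 3 : compressed,
    i mod 2      : dense,
    j mod 3      : dense
  )
}>

func.func @translate_example(%i: index, %j: index) -> (index, index, index, index) {
  // (i, j) -> (i floordiv 2, j floordiv 3, i mod 2, j mod 3)
  %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%i, %j] as #BSR
      : index, index, index, index
  return %l0, %l1, %l2, %l3 : index, index, index, index
}
```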

//===----------------------------------------------------------------------===//
// Sparse Tensor Management Operations. These operations are "impure" in the
// sense that some behavior is defined by side-effects. These operations provide
54 changes: 54 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1160,6 +1160,60 @@ bool ConvertOp::needsExtraSort() {
return true;
}

LogicalResult CrdTranslateOp::verify() {
uint64_t inRank = getOracle().getLvlRank();
uint64_t outRank = getOracle().getDimRank();

if (getDirection() == CrdTransDirectionKind::dim2lvl)
std::swap(inRank, outRank);

if (inRank != getInCrds().size() || outRank != getOutCrds().size())
return emitError("Coordinate rank mismatch with encoding");

return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (getOracle().isPermutation()) {
AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
? getOracle().getDimToLvl()
: getOracle().getLvlToDim();
for (AffineExpr exp : perm.getResults())
results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
return success();
}

// Fuse dim2lvl/lvl2dim pairs.
auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
return v.getDefiningOp() == def;
});
if (!sameDef)
return failure();

bool oppositeDir = def.getDirection() != getDirection();
bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
bool sameCount = def.getNumResults() == getInCrds().size();
if (!oppositeDir || !sameOracle || !sameCount)
return failure();

// The definition produces the coordinates in the same order as the input
// coordinates.
bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
[](auto valuePair) {
auto [lhs, rhs] = valuePair;
return lhs == rhs;
});

if (!sameOrder)
return failure();
// l1 = dim2lvl (lvl2dim l0)
// ==> l0
results.append(def.getInCrds().begin(), def.getInCrds().end());
return success();
}
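
A rough before/after sketch of the pair fusion above (hypothetical IR, using the same #BSR encoding as the fold.mlir test below): when a `lvl_to_dim` consumes exactly the results of a matching `dim_to_lvl` with the same encoding, the folder forwards the original dimension coordinates:

```mlir
// Before folding:
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #BSR
    : index, index, index, index
%r0, %r1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR
    : index, index

// After folding: the lvl_to_dim op is removed and its results are
// replaced by the original dimension coordinates %d0 and %d1.
```

For purely permuted encodings, the first branch of the folder applies instead and rewrites the op directly into a reordering of its input coordinates.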

LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
18 changes: 18 additions & 0 deletions mlir/test/Dialect/SparseTensor/fold.mlir
@@ -75,3 +75,21 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32,
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
return %ret : tensor<?x?xf32, #COO>
}


#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

// CHECK-LABEL: func @sparse_crd_translate(
// CHECK-NOT: sparse_tensor.crd_translate
func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
%d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
return %d0, %d1 : index, index
}
34 changes: 34 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -861,3 +861,37 @@ func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
return %ret : tensor<?x?xf64, #OrderedCOO>
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index) {
// expected-error@+1 {{Coordinate rank mismatch with encoding}}
%l0, %l1, %l2 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index
return %l0, %l1, %l2 : index, index, index
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
// expected-error@+1 {{Coordinate rank mismatch with encoding}}
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
return %l0, %l1, %l2, %l3 : index, index, index, index
}
22 changes: 22 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -647,3 +647,25 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
return %ret : tensor<?x?xf32, #OrderedCOO>
}


// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

// CHECK-LABEL: func.func @sparse_crd_translate(
// CHECK-SAME: %[[VAL_0:.*]]: index,
// CHECK-SAME: %[[VAL_1:.*]]: index)
// CHECK: %[[VAL_2:.*]]:4 = sparse_tensor.crd_translate dim_to_lvl{{\[}}%[[VAL_0]], %[[VAL_1]]]
// CHECK: return %[[VAL_2]]#0, %[[VAL_2]]#1, %[[VAL_2]]#2, %[[VAL_2]]#3
func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index, index) {
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
return %l0, %l1, %l2, %l3 : index, index, index, index
}