Commit faa75f9

Author: Peiming Liu

[mlir][sparse] reject kernels with non-sparsifiable reduction expression.

To address #59394. A reduction on a negation of the output tensor is a
non-sparsifiable kernel: it creates a cyclic dependency between iterations.
This patch rejects those cases instead of crashing.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D139659

1 parent 4efcea9 · commit faa75f9
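Why the cyclic dependency matters: in a reduction such as x = a(i) - x, every iteration negates the running value, so the final result depends on how many iterations execute, including those where a(i) == 0. A sparse kernel, by design, skips the zero entries, which would silently change the answer. A minimal scalar sketch in C++ (an illustration only, not part of the patch) makes this concrete:

#include <cstdio>
#include <vector>

// Scalar model of the reduction x = a(i) - x: the "dense" loop visits every
// element, while the "sparse" loop skips zeros the way a sparsified kernel
// would. The two results differ, which is why such kernels must be rejected.
int main() {
  std::vector<int> a = {3, 0, 5};

  int dense = 0;
  for (int v : a)
    dense = v - dense;     // 3, then 0 - 3 = -3, then 5 - (-3) = 8

  int sparse = 0;
  for (int v : a)
    if (v != 0)
      sparse = v - sparse; // 3, then 5 - 3 = 2

  std::printf("dense = %d, sparse = %d\n", dense, sparse); // dense = 8, sparse = 2
  return 0;
}

By contrast, x = x - a(i) only accumulates -a(i), so skipping zeros is harmless; that is why only negation of the output tensor must be rejected.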

File tree: 4 files changed (+192, -4 lines)


mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 14 additions & 4 deletions
@@ -271,6 +271,15 @@ class Merger {
     return ldx >= numNativeLoops;
   }
 
+  /// Returns true if the expression contains the tensor `t` as an operand.
+  bool expContainsTensor(unsigned e, unsigned t) const;
+
+  /// Returns true if the expression contains a negation on the output
+  /// tensor, i.e., `-outTensor` or `exp - outTensor`.
+  /// NOTE: this is a trivial test in that it does not handle recursive
+  /// negation, i.e., it returns true when the expression is `-(-tensor)`.
+  bool hasNegateOnOut(unsigned e) const;
+
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
@@ -348,9 +357,9 @@ class Merger {
   void dumpBits(const BitVector &bits) const;
 #endif
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
+  /// Builds the iteration lattices in a bottom-up traversal given the
+  /// remaining tensor (sub)expression and the next loop index in the
+  /// iteration graph. Returns index of the root expression.
   unsigned buildLattices(unsigned e, unsigned i);
 
   /// Builds a tensor expression from the given Linalg operation.
@@ -380,7 +389,8 @@
   // Map that converts pair<tensor id, loop id> to the corresponding dimension
   // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
-  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
+  // Map that converts pair<tensor id, loop id> to the corresponding
+  // dimension.
   std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   // Map that converts pair<tensor id, dim> to the corresponding loop id.
   std::vector<std::vector<Optional<unsigned>>> dimToLoopIdx;
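As the NOTE above says, the test is deliberately trivial: it does not cancel double negation. A kernel body along the following lines (a hypothetical fragment modeled on the test at the end of this commit, not code from the patch) would therefore still be flagged, even though it computes a + x:

^bb(%a: f32, %x: f32):
  %n  = arith.negf %x : f32      // negation of the output tensor is flagged ...
  %nn = arith.negf %n : f32      // ... even though this second negf cancels it
  %t  = arith.addf %a, %nn : f32
  linalg.yield %t : f32

Being conservative here only costs a missed optimization, never a miscompile.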

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 13 additions & 0 deletions
@@ -583,6 +583,19 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort, unsigned exp,
                                   OpOperand **sparseOut,
                                   unsigned &outerParNest) {
+  // We reject any expression that makes a reduction from `-outTensor`, as
+  // such expressions create a dependency between the current iteration (i)
+  // and the previous iteration (i-1). It would then require iterating over
+  // the whole coordinate space, which prevents us from exploiting sparsity
+  // for faster code.
+  for (utils::IteratorType it : op.getIteratorTypesArray()) {
+    if (it == utils::IteratorType::reduction) {
+      if (merger.hasNegateOnOut(exp))
+        return false;
+      break;
+    }
+  }
+
   OpOperand *lhs = op.getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
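Note the early break: hasNegateOnOut inspects the tensor expression tree, which is independent of any particular loop, so a single check after encountering the first reduction iterator suffices.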

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 126 additions & 0 deletions
@@ -18,6 +18,81 @@
 namespace mlir {
 namespace sparse_tensor {
 
+enum class ExpArity {
+  kNullary,
+  kUnary,
+  kBinary,
+};
+
+static ExpArity getExpArity(Kind k) {
+  switch (k) {
+  // Leaf.
+  case kTensor:
+  case kInvariant:
+  case kIndex:
+    return ExpArity::kNullary;
+  case kAbsF:
+  case kAbsC:
+  case kAbsI:
+  case kCeilF:
+  case kFloorF:
+  case kSqrtF:
+  case kSqrtC:
+  case kExpm1F:
+  case kExpm1C:
+  case kLog1pF:
+  case kLog1pC:
+  case kSinF:
+  case kSinC:
+  case kTanhF:
+  case kTanhC:
+  case kTruncF:
+  case kExtF:
+  case kCastFS:
+  case kCastFU:
+  case kCastSF:
+  case kCastUF:
+  case kCastS:
+  case kCastU:
+  case kCastIdx:
+  case kTruncI:
+  case kCIm:
+  case kCRe:
+  case kBitCast:
+  case kBinaryBranch:
+  case kUnary:
+  case kSelect:
+  case kNegF:
+  case kNegC:
+  case kNegI:
+    return ExpArity::kUnary;
+  // Binary operations.
+  case kDivF:
+  case kDivC:
+  case kDivS:
+  case kDivU:
+  case kShrS:
+  case kShrU:
+  case kShlI:
+  case kMulF:
+  case kMulC:
+  case kMulI:
+  case kAndI:
+  case kAddF:
+  case kAddC:
+  case kAddI:
+  case kOrI:
+  case kXorI:
+  case kBinary:
+  case kReduce:
+  case kSubF:
+  case kSubC:
+  case kSubI:
+    return ExpArity::kBinary;
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 //===----------------------------------------------------------------------===//
 // Constructors.
 //===----------------------------------------------------------------------===//
@@ -310,6 +385,57 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   return !hasAnySparse(tmp);
 }
 
+bool Merger::expContainsTensor(unsigned e, unsigned t) const {
+  if (tensorExps[e].kind == kTensor)
+    return tensorExps[e].tensor == t;
+
+  switch (getExpArity(tensorExps[e].kind)) {
+  case ExpArity::kNullary:
+    return false;
+  case ExpArity::kUnary: {
+    unsigned op = tensorExps[e].children.e0;
+    if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t)
+      return true;
+    return expContainsTensor(op, t);
+  }
+  case ExpArity::kBinary: {
+    unsigned op1 = tensorExps[e].children.e0;
+    unsigned op2 = tensorExps[e].children.e1;
+    if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) ||
+        (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t))
+      return true;
+    return expContainsTensor(op1, t) || expContainsTensor(op2, t);
+  }
+  }
+  llvm_unreachable("unexpected arity");
+}
+
+bool Merger::hasNegateOnOut(unsigned e) const {
+  switch (tensorExps[e].kind) {
+  case kNegF:
+  case kNegC:
+  case kNegI:
+    return expContainsTensor(tensorExps[e].children.e0, outTensor);
+  case kSubF:
+  case kSubC:
+  case kSubI:
+    return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
+           hasNegateOnOut(tensorExps[e].children.e0);
+  default: {
+    switch (getExpArity(tensorExps[e].kind)) {
+    case ExpArity::kNullary:
+      return false;
+    case ExpArity::kUnary:
+      return hasNegateOnOut(tensorExps[e].children.e0);
+    case ExpArity::kBinary:
+      return hasNegateOnOut(tensorExps[e].children.e0) ||
+             hasNegateOnOut(tensorExps[e].children.e1);
+    }
+  }
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
   // Leaf.
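To see the traversal in isolation, here is a self-contained miniature (a simplified sketch: pointer-based nodes and only four expression kinds, instead of the Merger's index-based tensorExps array):

#include <cassert>
#include <memory>

// Stand-alone model (an assumption-laden sketch, not the Merger's actual
// representation) of the expContainsTensor / hasNegateOnOut traversal above.
struct Exp {
  enum Kind { kTensor, kNeg, kSub, kAdd } kind;
  int tensor;                  // tensor id, valid only for kTensor
  std::shared_ptr<Exp> e0, e1; // children (e1 unused for kNeg)
};

// Does expression `e` mention tensor `t` anywhere?
static bool contains(const Exp &e, int t) {
  if (e.kind == Exp::kTensor)
    return e.tensor == t;
  return (e.e0 && contains(*e.e0, t)) || (e.e1 && contains(*e.e1, t));
}

// Reject `-(...out...)` and `lhs - (...out...)`; otherwise recurse.
static bool hasNegateOnOut(const Exp &e, int out) {
  switch (e.kind) {
  case Exp::kNeg:
    return contains(*e.e0, out);
  case Exp::kSub:
    return contains(*e.e1, out) || hasNegateOnOut(*e.e0, out);
  case Exp::kTensor:
    return false;
  default: // other operations: look for a negation in any child
    return (e.e0 && hasNegateOnOut(*e.e0, out)) ||
           (e.e1 && hasNegateOnOut(*e.e1, out));
  }
}

int main() {
  auto tensor = [](int t) {
    auto e = std::make_shared<Exp>();
    e->kind = Exp::kTensor;
    e->tensor = t;
    return e;
  };
  // With tensor 1 as the output x: `a - x` is rejected ...
  Exp bad{Exp::kSub, -1, tensor(0), tensor(1)};
  assert(hasNegateOnOut(bad, 1));
  // ... but `x - a` (i.e., x -= a) is fine.
  Exp ok{Exp::kSub, -1, tensor(1), tensor(0)};
  assert(!hasNegateOnOut(ok, 1));
  return 0;
}

The kSub case mirrors the code above: only the right child is tested for containing the output (lhs - out negates out), while the left child is searched recursively for further negations.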
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+
+// This file contains examples that will be rejected by the sparse compiler
+// (we expect the linalg.generic to remain unchanged).
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a (in)
+    affine_map<(i) -> ()>   // x (out)
+  ],
+  iterator_types = ["reduction"]
+}
+
+// CHECK-LABEL: func.func @sparse_reduction_subi(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<i32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<i32> {
+// CHECK:         %[[VAL_2:.*]] = linalg.generic
+// CHECK:         ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK:           %[[VAL_5:.*]] = arith.subi %[[VAL_3]], %[[VAL_4]] : i32
+// CHECK:           linalg.yield %[[VAL_5]] : i32
+// CHECK:         } -> tensor<i32>
+// CHECK:         return %[[VAL_6:.*]] : tensor<i32>
+func.func @sparse_reduction_subi(%argx: tensor<i32>,
+                                 %arga: tensor<?xi32, #SparseVector>)
+  -> tensor<i32> {
+  %0 = linalg.generic #trait
+    ins(%arga: tensor<?xi32, #SparseVector>)
+    outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %x: i32):
+        // NOTE: `subi %a, %x` is the reason the program is rejected by the
+        // sparse compiler: we do not allow `-outTensor` in reduction loops,
+        // as it creates cyclic dependences.
+        %t = arith.subi %a, %x : i32
+        linalg.yield %t : i32
+  } -> tensor<i32>
+  return %0 : tensor<i32>
+}
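For contrast, consider a hypothetical variant (not part of the committed test; the name @sparse_reduction_subi_ok is ours) that swaps the operands to compute x = x - a(i). The output is never negated, so hasNegateOnOut returns false and the kernel should sparsify like an ordinary x -= a(i) reduction:

func.func @sparse_reduction_subi_ok(%argx: tensor<i32>,
                                    %arga: tensor<?xi32, #SparseVector>)
  -> tensor<i32> {
  %0 = linalg.generic #trait
    ins(%arga: tensor<?xi32, #SparseVector>)
    outs(%argx: tensor<i32>) {
      ^bb(%a: i32, %x: i32):
        %t = arith.subi %x, %a : i32 // x - a(i): only the input is negated
        linalg.yield %t : i32
  } -> tensor<i32>
  return %0 : tensor<i32>
}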
