
Commit 849529b

Author: Peiming Liu
[mlir][sparse] fix performance bug in matmul with a sparse rhs due to suboptimal iteration graphs.
While dense tensors support random access, it is critical to visit them in row-major order for better cache locality. However, we previously considered dense inputs and outputs together when computing constraints for building the iteration graph, which could lead to less efficient iteration graphs. This patch adds a new `SortMask::kIncludeDenseInput` so that dense inputs and outputs are treated separately when building the iteration graph, increasing the chance of constructing a better one. An even more fine-grained approach would be to treat each input separately.

Note: related to #51651

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144932
1 parent 35d17c1 commit 849529b
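
To make the cache-locality argument concrete, the sketch below (standalone C++, not part of this patch; the CSR struct and all names are purely illustrative) shows the kernel shape that a dense-input-friendly iteration graph yields for X(i,j) += A(i,k) * B(k,j) with a sparse right-hand side B: the loop order i, k, j walks the dense input A and the dense output X row by row while scanning each row of B exactly once.

#include <cstdint>
#include <vector>

// Illustrative CSR view of the sparse right-hand side B (K x N).
struct CSR {
  std::vector<int64_t> ptr; // row pointers, size K + 1
  std::vector<int64_t> idx; // column indices of stored entries
  std::vector<double> val;  // stored values
};

// X(i,j) += A(i,k) * B(k,j) with dense, row-major A (M x K) and X (M x N).
// The i < k < j loop order keeps every dense access row-major.
void matmulSparseRhs(const double *A, const CSR &B, double *X,
                     int64_t M, int64_t K, int64_t N) {
  for (int64_t i = 0; i < M; ++i)        // rows of A and X
    for (int64_t k = 0; k < K; ++k) {    // columns of A == rows of B
      double a = A[i * K + k];           // unit-stride read of A
      for (int64_t p = B.ptr[k]; p < B.ptr[k + 1]; ++p)
        X[i * N + B.idx[p]] += a * B.val[p]; // forward walk within row i of X
    }
}

When the dense-input and dense-output constraints are only available lumped together, the sorter may have to drop both and settle for an order that strides through one of the dense operands; exposing kIncludeDenseInput on its own makes orders like the one above reachable more often.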

File tree

4 files changed: +219 −144 lines

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 55 additions & 26 deletions
@@ -47,12 +47,29 @@ namespace {
 
 /// Iteration graph sorting.
 enum SortMask {
-  kSparseOnly = 0x0,
-  kIncludeDense = 0x1,
-  kIncludeUndef = 0x2,
-  kIncludeAll = 0x3
+  // The individual mask bits.
+  kIncludeDenseOutput = 0x1, // b001
+  kIncludeDenseInput = 0x2,  // b010
+  kIncludeUndef = 0x4,       // b100
+  // The subsets of mask bits.
+  kIncludeAll = 0x7,   // b111
+  kIncludeDense = 0x3, // b011
+  kSparseOnly = 0x0,   // b000
 };
 
+/// SortMask tests on individual bits.
+inline static bool includeDenseInput(unsigned mask) {
+  return mask & SortMask::kIncludeDenseInput;
+}
+
+inline static bool includeDenseOutput(unsigned mask) {
+  return mask & SortMask::kIncludeDenseOutput;
+}
+
+inline static bool includeUndef(unsigned mask) {
+  return mask & SortMask::kIncludeUndef;
+}
+
 /// A helper class that visits an affine expression and tries to find an
 /// AffineDimExpr to which the corresponding iterator from a GenericOp matches
 /// the desired iterator type.
@@ -453,9 +470,35 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
     const auto map = env.op().getMatchingIndexingMap(&t);
     const auto enc = getSparseTensorEncoding(t.get().getType());
     assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
-    // Skip dense tensor constraints when not requested.
-    if (!(mask & SortMask::kIncludeDense) && !enc)
+
+    bool isDenseInput = !enc && env.op().isDpsInput(&t);
+    bool isDenseOutput = !enc && !isDenseInput;
+
+    // Skips dense inputs/outputs when not requested.
+    if ((isDenseInput && !includeDenseInput(mask)) ||
+        (isDenseOutput && !includeDenseOutput(mask)))
       continue;
+
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    // TODO: Do we really need this?
+    if (includeUndef(mask)) {
+      unsigned tensor = t.getOperandNumber();
+      for (unsigned i = 0; i < n; i++) {
+        if (isCompressedDLT(env.dlt(tensor, i)) ||
+            isSingletonDLT(env.dlt(tensor, i))) {
+          for (unsigned j = 0; j < n; j++)
+            if (isUndefDLT(env.dlt(tensor, j))) {
+              adjM[i][j] = true;
+              inDegree[j]++;
+            }
+        } else {
+          assert(isDenseDLT(env.dlt(tensor, i)) ||
+                 isUndefDLT(env.dlt(tensor, i)));
+        }
+      }
+    }
+
     // Each tensor expression and optional dimension ordering (row-major
     // by default) puts an ordering constraint on the loop indices. For
     // example, the tensor expresion A_ijk forces the ordering i < j < k
@@ -508,24 +551,6 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
         addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
       }
     }
-    // Push unrelated loops into sparse iteration space, so these
-    // will be skipped more often.
-    if (mask & SortMask::kIncludeUndef) {
-      unsigned tensor = t.getOperandNumber();
-      for (unsigned i = 0; i < n; i++) {
-        if (isCompressedDLT(env.dlt(tensor, i)) ||
-            isSingletonDLT(env.dlt(tensor, i))) {
-          for (unsigned j = 0; j < n; j++)
-            if (isUndefDLT(env.dlt(tensor, j))) {
-              adjM[i][j] = true;
-              inDegree[j]++;
-            }
-        } else {
-          assert(isDenseDLT(env.dlt(tensor, i)) ||
-                 isUndefDLT(env.dlt(tensor, i)));
-        }
-      }
-    }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
@@ -1532,8 +1557,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
     // An const list of all masks that we used for interation graph
    // computation. Must be ordered from more strict to less strict.
-    const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
-                          SortMask::kIncludeDense, SortMask::kSparseOnly};
+    // Ideally (though might not be guaranteed), the eariler a constraint mask
+    // can be satisfied, the faster the generated kernel will be.
+    const auto allMask = {
+        SortMask::kIncludeAll,        SortMask::kIncludeDense,
+        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
+        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
     for (auto mask : allMask) {
       if (computeIterationGraph(env, mask)) {
        hasCycle = false;
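
As a side note, the new mask values are plain bit flags, so the "subset" masks are just unions of the individual bits. The standalone snippet below (not part of the patch) only restates the bit layout and the relaxation order that GenericOpSparsifier now walks through, as a quick reference.

#include <cassert>

// Mask constants copied from the patch above; the checks restate
// relationships that already hold by construction.
enum SortMask : unsigned {
  kIncludeDenseOutput = 0x1, // b001
  kIncludeDenseInput = 0x2,  // b010
  kIncludeUndef = 0x4,       // b100
  kIncludeAll = 0x7,         // b111
  kIncludeDense = 0x3,       // b011
  kSparseOnly = 0x0          // b000
};

int main() {
  // kIncludeDense is the union of the two dense bits, so a mask can now ask
  // for dense-input constraints without also taking dense-output ones.
  static_assert((kIncludeDenseInput | kIncludeDenseOutput) == kIncludeDense, "");
  static_assert((kIncludeDense | kIncludeUndef) == kIncludeAll, "");
  assert((kIncludeDenseInput & kIncludeDenseOutput) == 0); // disjoint bits
  // The retry order progressively relaxes the constraints:
  // kIncludeAll, kIncludeDense, kIncludeDenseInput,
  // kIncludeDenseOutput, kIncludeUndef, kSparseOnly.
  return 0;
}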

mlir/test/Dialect/SparseTensor/sparse_2d.mlir

Lines changed: 29 additions & 30 deletions
@@ -1002,46 +1002,45 @@ func.func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor
   doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
 }
 
-// CHECK-LABEL: func @sampled_dense_dense(
+// CHECK-LABEL: func.func @sampled_dense_dense(
 // CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
 // CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[VAL_2:.*2]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
 // CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf32>
-// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf32>
-// CHECK-DAG: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf32>
 // CHECK-DAG: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
-// CHECK-DAG: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK-DAG: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
-// CHECK-DAG: %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
-// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
-// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<?x?xf32>
-// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref<?x?xf32>
-// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_32]], %[[VAL_33]] : f32
-// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_27]], %[[VAL_34]] : f32
-// CHECK: %[[VAL_36:.*]] = arith.addf %[[VAL_31]], %[[VAL_35]] : f32
-// CHECK: scf.yield %[[VAL_36]] : f32
-// CHECK: }
-// CHECK: memref.store %[[VAL_29]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_38:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<?x?xf32>
-// CHECK: return %[[VAL_38]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_4]] {
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]], %[[VAL_19]]] : memref<?x?xf32>
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_17]], %[[VAL_4]] : index
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_4]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref<?x?xf32>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]], %[[VAL_25]]] : memref<?x?xf32>
+// CHECK: %[[VAL_29:.*]] = arith.mulf %[[VAL_20]], %[[VAL_28]] : f32
+// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_27]], %[[VAL_29]] : f32
+// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_26]], %[[VAL_30]] : f32
+// CHECK: memref.store %[[VAL_31]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref<?x?xf32>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<?x?xf32>
+// CHECK: return %[[VAL_32]] : tensor<?x?xf32>
 // CHECK: }
 func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
                                %arga: tensor<?x?xf32>,
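
Reading the updated CHECK lines: the k-reduction, which used to be the innermost loop (the scf.for with iter_args), now sits between the loop over the compressed rows of the sampling matrix S and the loop over its stored columns, and partial sums go through the output buffer instead of an scf iter_arg. Roughly, the two schedules correspond to the following C++ (inferred from the test output, not code from this commit; the CSR struct and names are illustrative):

#include <cstdint>
#include <vector>

// Illustrative CSR view of the sparse sampling matrix S.
struct CSR {
  std::vector<int64_t> ptr; // row pointers
  std::vector<int64_t> idx; // column indices of stored entries
  std::vector<float> val;   // stored values
};

// Previous schedule: reduction over k innermost; B(k,j) is walked down a
// column (stride ldb) for every stored entry of S.
void sddmmOld(const CSR &S, const float *A, const float *B, float *X,
              int64_t K, int64_t lda, int64_t ldb, int64_t ldx) {
  int64_t rows = (int64_t)S.ptr.size() - 1;
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t p = S.ptr[i]; p < S.ptr[i + 1]; ++p) {
      int64_t j = S.idx[p];
      float acc = X[i * ldx + j];
      for (int64_t k = 0; k < K; ++k)
        acc += S.val[p] * A[i * lda + k] * B[k * ldb + j];
      X[i * ldx + j] = acc;
    }
}

// New schedule: the dense k loop is in the middle; A(i,k) is hoisted out of
// the inner loop and row k of B is read forward across the stored columns j.
void sddmmNew(const CSR &S, const float *A, const float *B, float *X,
              int64_t K, int64_t lda, int64_t ldb, int64_t ldx) {
  int64_t rows = (int64_t)S.ptr.size() - 1;
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t k = 0; k < K; ++k) {
      float a = A[i * lda + k];
      for (int64_t p = S.ptr[i]; p < S.ptr[i + 1]; ++p) {
        int64_t j = S.idx[p];
        X[i * ldx + j] += S.val[p] * a * B[k * ldb + j];
      }
    }
}

Both variants compute the same result; they differ only in which operands are streamed contiguously and in how often each output entry is touched, which is exactly the trade-off the new sort-mask split lets the iteration-graph sorter weigh.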
