
Commit 5508516

[mlir][sparse] retry sparse-only for cyclic iteration graphs
This is a very minor improvement during iteration graph construction. If the first attempt considering the dimension order of all tensors fails, a second attempt is made using the constraints of sparse tensors only. Dense tensors prefer dimension order (locality) but provide random access if needed, enabling the compilation of more sparse kernels.

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D94709
1 parent 39665d9

2 files changed: +111 -7 lines changed

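The commit message describes the strategy in a few words; the following is a minimal standalone C++ sketch of the idea, not the patch itself. The type TensorAccess and the function computeLoopOrder are invented here for illustration (only topSortDFS also exists in Sparsification.cpp, as the diff below shows): every tensor contributes ordering edges between consecutive loop indices of its access pattern, the resulting graph is topologically sorted, and if a cycle is found the construction is retried with the constraints of sparse tensors only, since dense tensors tolerate any loop order.

// Minimal standalone model of the retry (illustration only, not the MLIR code).
#include <cstdio>
#include <vector>

struct TensorAccess {
  std::vector<unsigned> indexOrder; // loop indices in tensor access order
  bool isSparse;                    // tensor has at least one sparse dimension
};

// Depth-first topological sort; returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       const std::vector<std::vector<bool>> &adjM) {
  if (visit[i] == 1)
    return false; // still on the DFS stack: back-edge, i.e. a cycle
  if (visit[i] == 2)
    return true; // already finished
  visit[i] = 1;
  for (unsigned j = 0, n = adjM.size(); j < n; j++)
    if (adjM[i][j] && !topSortDFS(j, visit, topSort, adjM))
      return false;
  visit[i] = 2;
  topSort.push_back(i); // post-order: reverse of a topological order
  return true;
}

// Builds the n x n iteration graph and topologically sorts it; when
// sparseOnly is set, the constraints of dense tensors are skipped.
static bool computeLoopOrder(const std::vector<TensorAccess> &tensors,
                             unsigned n, std::vector<unsigned> &topSort,
                             bool sparseOnly) {
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
  for (const TensorAccess &t : tensors) {
    if (sparseOnly && !t.isSparse)
      continue; // dense tensors allow random access, so drop their preference
    for (unsigned d = 1, e = t.indexOrder.size(); d < e; d++)
      adjM[t.indexOrder[d - 1]][t.indexOrder[d]] = true;
  }
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0 && !topSortDFS(i, visit, topSort, adjM))
      return false;
  return true;
}

int main() {
  // A(i,j) is dense, B(j,i) has a sparse dimension: i->j and j->i together
  // form a cycle, but the sparse-only graph (just j->i) is acyclic.
  std::vector<TensorAccess> tensors = {{{0, 1}, /*isSparse=*/false},
                                       {{1, 0}, /*isSparse=*/true}};
  std::vector<unsigned> topSort;
  if (!computeLoopOrder(tensors, 2, topSort, /*sparseOnly=*/false) &&
      !computeLoopOrder(tensors, 2, topSort, /*sparseOnly=*/true))
    return 1; // still cyclic even with sparse constraints only: give up
  for (auto it = topSort.rbegin(); it != topSort.rend(); ++it)
    std::printf("loop index %u\n", *it); // prints 1 then 0: j outer, i inner
  return 0;
}

In this sketch the first attempt fails on the cycle i->j->i and the retry yields the loop order j, i (indices 1 and 0), which is the same situation the eight-dimensional test added in this commit creates on a larger scale.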
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp

Lines changed: 17 additions & 7 deletions
@@ -274,6 +274,11 @@ class Merger {
     return false;
   }
 
+  // Returns true if tensor has any sparse dimension.
+  bool isSparseTensor(unsigned t) const {
+    return llvm::any_of(dims[t], [](Dim d) { return d == Dim::kSparse; });
+  }
+
   // Setter
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -382,17 +387,22 @@ static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
 /// for sparse storage formats since these only support access along fixed
 /// dimensions. Even for dense storage formats, however, the natural index
 /// order yields innermost unit-stride access with better spatial locality.
-static bool computeIterationGraph(linalg::GenericOp op,
-                                  std::vector<unsigned> &topSort) {
+static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
+                                  std::vector<unsigned> &topSort,
+                                  bool sparseOnly) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
 
   // Iterate over the indexing maps of every tensor in the tensor expression.
-  for (auto imap : llvm::enumerate(op.indexing_maps())) {
-    auto map = imap.value().template cast<AffineMapAttr>().getValue();
+  unsigned numTensors = op.getNumShapedOperands();
+  for (unsigned t = 0; t < numTensors; t++) {
+    auto map = op.getIndexingMap(t);
     assert(map.getNumDims() == n);
+    // Skip dense tensor constraints when sparse only is requested.
+    if (sparseOnly && !merger.isSparseTensor(t))
+      continue;
     // At the moment, we take the index variables in the tensor access
     // expression in the order in which they appear (conceptually a
     // "row-major" layout of every tensor). So, a tensor access A_ijk
@@ -407,6 +417,7 @@ static bool computeIterationGraph(linalg::GenericOp op,
 
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
+  topSort.clear();
   topSort.reserve(n);
   std::vector<unsigned> visit(n, 0);
   for (unsigned i = 0; i < n; i++)
@@ -1207,10 +1218,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // tensors are visited in natural index order. Fails on cycles.
     // This assumes that higher-level passes have already put the
     // tensors in each tensor expression in a feasible order.
-    // TODO: try again without *dense* constraints on failure or
-    //       even try to insert sparse reorderings to resolve cycles
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(op, topSort))
+    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
+        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
       return failure();
 
     // Finds the terminating yield statement and builds the tensor
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s -test-sparsification | FileCheck %s

// Example with cyclic iteration graph with sparse and dense constraints,
// but an acyclic iteration graph using sparse constraints only.
#trait_mul = {
  indexing_maps = [
    affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>,  // A
    affine_map<(i,j,k,l,m,n,o,p) -> (p,o,n,m,l,k,j,i)>,  // B
    affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>   // X
  ],
  sparse = [
    [ "D", "D", "D", "D", "D", "D", "D", "D" ],  // a
    [ "D", "D", "D", "S", "S", "D", "D", "D" ],  // b
    [ "D", "D", "D", "D", "D", "D", "D", "D" ]   // x
  ],
  iterator_types = ["parallel", "parallel", "parallel", "parallel",
                    "parallel", "parallel", "parallel", "parallel"],
  doc = "X(i,j,k,l,m,n,o,p) = A(i,j,k,l,m,n,o,p) * B(p,o,n,m,l,k,j,i)"
}

// CHECK-LABEL: func @mul(
// CHECK-SAME:  %[[VAL_0:.*]]: tensor<100x200x300x400x500x600x700x800xf32>,
// CHECK-SAME:  %[[VAL_1:.*]]: tensor<100x200x300x400x500x600x700x800xf32>) -> tensor<100x200x300x400x500x600x700x800xf32> {
// CHECK:   %[[VAL_2:.*]] = constant 999 : index
// CHECK:   %[[VAL_3:.*]] = constant 100 : index
// CHECK:   %[[VAL_4:.*]] = constant 200 : index
// CHECK:   %[[VAL_5:.*]] = constant 300 : index
// CHECK:   %[[VAL_6:.*]] = constant 600 : index
// CHECK:   %[[VAL_7:.*]] = constant 700 : index
// CHECK:   %[[VAL_8:.*]] = constant 800 : index
// CHECK:   %[[VAL_9:.*]] = constant 0 : index
// CHECK:   %[[VAL_10:.*]] = constant 1 : index
// CHECK:   %[[VAL_11:.*]] = alloca() : memref<100x200x300x400x500x600x700x800xf32>
// CHECK:   %[[VAL_12:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
// CHECK:   %[[VAL_13:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
// CHECK:   %[[VAL_14:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
// CHECK:   %[[VAL_15:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
// CHECK:   %[[VAL_16:.*]] = alloca(%[[VAL_2]]) : memref<?xf32>
// CHECK:   %[[VAL_17:.*]] = alloca() : memref<100x200x300x400x500x600x700x800xf32>
// CHECK:   scf.for %[[VAL_18:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_10]] {
// CHECK:     scf.for %[[VAL_19:.*]] = %[[VAL_9]] to %[[VAL_7]] step %[[VAL_10]] {
// CHECK:       %[[VAL_20:.*]] = muli %[[VAL_18]], %[[VAL_7]] : index
// CHECK:       %[[VAL_21:.*]] = addi %[[VAL_20]], %[[VAL_19]] : index
// CHECK:       scf.for %[[VAL_22:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_10]] {
// CHECK:         %[[VAL_23:.*]] = muli %[[VAL_21]], %[[VAL_6]] : index
// CHECK:         %[[VAL_24:.*]] = addi %[[VAL_23]], %[[VAL_22]] : index
// CHECK:         %[[VAL_25:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:         %[[VAL_26:.*]] = addi %[[VAL_24]], %[[VAL_10]] : index
// CHECK:         %[[VAL_27:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK:         scf.for %[[VAL_28:.*]] = %[[VAL_25]] to %[[VAL_27]] step %[[VAL_10]] {
// CHECK:           %[[VAL_29:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK:           %[[VAL_30:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK:           %[[VAL_31:.*]] = addi %[[VAL_28]], %[[VAL_10]] : index
// CHECK:           %[[VAL_32:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_33:.*]] = %[[VAL_30]] to %[[VAL_32]] step %[[VAL_10]] {
// CHECK:             %[[VAL_34:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_33]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_35:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_10]] {
// CHECK:               %[[VAL_36:.*]] = muli %[[VAL_33]], %[[VAL_5]] : index
// CHECK:               %[[VAL_37:.*]] = addi %[[VAL_36]], %[[VAL_35]] : index
// CHECK:               scf.for %[[VAL_38:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_10]] {
// CHECK:                 %[[VAL_39:.*]] = muli %[[VAL_37]], %[[VAL_4]] : index
// CHECK:                 %[[VAL_40:.*]] = addi %[[VAL_39]], %[[VAL_38]] : index
// CHECK:                 scf.for %[[VAL_41:.*]] = %[[VAL_9]] to %[[VAL_3]] step %[[VAL_10]] {
// CHECK:                   %[[VAL_42:.*]] = muli %[[VAL_40]], %[[VAL_3]] : index
// CHECK:                   %[[VAL_43:.*]] = addi %[[VAL_42]], %[[VAL_41]] : index
// CHECK:                   %[[VAL_44:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_41]], %[[VAL_38]], %[[VAL_35]], %[[VAL_34]], %[[VAL_29]], %[[VAL_22]], %[[VAL_19]], %[[VAL_18]]] : memref<100x200x300x400x500x600x700x800xf32>
// CHECK:                   %[[VAL_45:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_43]]] : memref<?xf32>
// CHECK:                   %[[VAL_46:.*]] = mulf %[[VAL_44]], %[[VAL_45]] : f32
// CHECK:                   store %[[VAL_46]], %[[VAL_17]]{{\[}}%[[VAL_41]], %[[VAL_38]], %[[VAL_35]], %[[VAL_34]], %[[VAL_29]], %[[VAL_22]], %[[VAL_19]], %[[VAL_18]]] : memref<100x200x300x400x500x600x700x800xf32>
// CHECK:                 }
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:         }
// CHECK:       }
// CHECK:     }
// CHECK:   }
// CHECK:   %[[VAL_47:.*]] = tensor_load %[[VAL_17]] : memref<100x200x300x400x500x600x700x800xf32>
// CHECK:   return %[[VAL_47]] : tensor<100x200x300x400x500x600x700x800xf32>
// CHECK: }
func @mul(%arga: tensor<100x200x300x400x500x600x700x800xf32>,
          %argb: tensor<100x200x300x400x500x600x700x800xf32>)
              -> tensor<100x200x300x400x500x600x700x800xf32> {
  %0 = linalg.generic #trait_mul
    ins(%arga, %argb: tensor<100x200x300x400x500x600x700x800xf32>,
                      tensor<100x200x300x400x500x600x700x800xf32>)
    outs(%arga: tensor<100x200x300x400x500x600x700x800xf32>) {
    ^bb(%a: f32, %b: f32, %s : f32):
      %0 = mulf %a, %b : f32
      linalg.yield %0 : f32
  } -> tensor<100x200x300x400x500x600x700x800xf32>
  return %0 : tensor<100x200x300x400x500x600x700x800xf32>
}
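In this test only b carries sparse dimensions (the "S" entries, corresponding to the dimensions of size 500 and 400), so the first attempt combines A's order i through p with B's reversed order p through i and fails on the resulting cycle, while the sparse-only retry keeps just B's constraints. The CHECK lines reflect exactly that order: the generated loop nest runs over the 800, 700 and 600 dimensions, then the two sparse levels of b, then the 300, 200 and 100 dimensions.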
