Skip to content

Commit b228e2b

Browse files
committed
[mlir][sparse] generalize invariant expression handling in sparse compiler
Generalizes invariant handling to anything defined outside the Linalg op (parameters and SSA computations). Fixes bug that was using parameter number as tensor number. Reviewed By: penpornk Differential Revision: https://reviews.llvm.org/D91985
1 parent 4f5355e commit b228e2b

File tree

2 files changed

+74
-12
lines changed

2 files changed

+74
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,16 @@ namespace {
5454
enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
5555

5656
/// Tensor expression. Represents a MLIR expression in tensor index notation.
57-
/// For tensors and invariants, e0 denotes the tensor index. For all binary
58-
/// operations, e0 and e1 denote the index of the children tensor expressions.
57+
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
58+
/// stored directly. For binary operations, e0 and e1 denote the index of the
59+
/// children tensor expressions.
5960
struct TensorExp {
60-
TensorExp(Kind k, unsigned x, unsigned y) : kind(k), e0(x), e1(y) {}
61+
TensorExp(Kind k, unsigned x, unsigned y, Value v)
62+
: kind(k), e0(x), e1(y), val(v) {}
6163
Kind kind;
6264
unsigned e0;
6365
unsigned e1;
66+
Value val;
6467
};
6568

6669
/// Lattice point. Each lattice point consists of a conjunction of tensor
@@ -85,11 +88,12 @@ class Merger {
8588
: numTensors(t), numLoops(l), isSparse(t, std::vector<bool>(l, false)) {}
8689

8790
/// Adds a tensor expression. Returns its index.
88-
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u) {
91+
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
8992
unsigned e = tensorExps.size();
90-
tensorExps.push_back(TensorExp(k, e0, e1));
93+
tensorExps.push_back(TensorExp(k, e0, e1, v));
9194
return e;
9295
}
96+
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
9397

9498
/// Adds an iteration lattice point. Returns its index.
9599
unsigned addLat(unsigned t, unsigned i, unsigned e) {
@@ -339,7 +343,6 @@ static bool computeIterationGraph(linalg::GenericOp op,
339343
/// building (compared to using the SSA representation everywhere).
340344
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
341345
Value val) {
342-
Operation *def = val.getDefiningOp();
343346
if (auto arg = val.dyn_cast<BlockArgument>()) {
344347
unsigned argN = arg.getArgNumber();
345348
if (arg.getOwner()->getParentOp() == op) {
@@ -348,10 +351,16 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
348351
auto map = op.getIndexingMap(argN);
349352
if (map.isProjectedPermutation())
350353
return merger.addExp(Kind::kTensor, argN);
351-
} else {
352-
// Any parameter of a higher op is invariant in the tensor expression.
353-
return merger.addExp(Kind::kInvariant, argN);
354+
// Cannot handle (yet).
355+
return None;
354356
}
357+
// Any parameter of a higher op is invariant.
358+
return merger.addExp(Kind::kInvariant, val);
359+
}
360+
Operation *def = val.getDefiningOp();
361+
if (def->getBlock() != &op.region().front()) {
362+
// Something defined outside is invariant.
363+
return merger.addExp(Kind::kInvariant, val);
355364
} else if (def->getNumOperands() == 2) {
356365
// Construct binary operations if subexpressions could be built.
357366
auto x = buildTensorExp(merger, op, def->getOperand(0));
@@ -380,9 +389,12 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
380389
Kind kind = merger.exp(exp).kind;
381390
if (kind == Kind::kTensor || kind == Kind::kInvariant) {
382391
// Either the index is really used in the tensor expression, or it is
383-
// set to the "non-existing dense index" in that dimension.
392+
// set to the "non-existing dense index" in that dimension. Invariant
393+
// expressions borrow the output tensor indices.
384394
unsigned s = merger.addSet();
385-
merger.set(s).push_back(merger.addLat(merger.exp(exp).e0, idx, exp));
395+
unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
396+
: op.getNumInputsAndOutputs() - 1;
397+
merger.set(s).push_back(merger.addLat(t, idx, exp));
386398
return s;
387399
}
388400
unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
@@ -502,7 +514,7 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
502514
if (merger.exp(exp).kind == Kind::kTensor)
503515
return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0);
504516
else if (merger.exp(exp).kind == Kind::kInvariant)
505-
return op.getParentRegion()->front().getArgument(merger.exp(exp).e0);
517+
return merger.exp(exp).val;
506518
Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
507519
Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
508520
switch (merger.exp(exp).kind) {

mlir/test/Dialect/Linalg/sparse_2d.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,56 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
11061106
return %0 : tensor<f32>
11071107
}
11081108

1109+
#trait_scale = {
1110+
indexing_maps = [
1111+
affine_map<(i,j) -> (i,j)>, // A
1112+
affine_map<(i,j) -> (i,j)> // X (out)
1113+
],
1114+
sparse = [
1115+
[ "D", "S" ], // A
1116+
[ "D", "D" ] // X
1117+
],
1118+
iterator_types = ["parallel", "parallel"],
1119+
doc = "X(i,j) = A(i,j) * SCALE"
1120+
}
1121+
1122+
// CHECK-LABEL: func @scale(
1123+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64>) -> tensor<?x?xf64> {
1124+
// CHECK: %[[VAL_1:.*]] = constant 2.000000e+00 : f64
1125+
// CHECK: %[[VAL_2:.*]] = constant 999 : index
1126+
// CHECK: %[[VAL_3:.*]] = constant 0 : index
1127+
// CHECK: %[[VAL_4:.*]] = constant 1 : index
1128+
// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
1129+
// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
1130+
// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
1131+
// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
1132+
// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
1133+
// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref<?x?xf64>
1134+
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
1135+
// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
1136+
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index
1137+
// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
1138+
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] {
1139+
// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
1140+
// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xf64>
1141+
// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64
1142+
// CHECK: store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<?x?xf64>
1143+
// CHECK: }
1144+
// CHECK: }
1145+
// CHECK: %[[VAL_19:.*]] = tensor_load %[[VAL_10]] : memref<?x?xf64>
1146+
// CHECK: return %[[VAL_19]] : tensor<?x?xf64>
1147+
// CHECK: }
1148+
func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
1149+
%0 = constant 2.0 : f64
1150+
%1 = linalg.generic #trait_scale
1151+
ins(%arga: tensor<?x?xf64>) {
1152+
^bb(%a: f64):
1153+
%2 = mulf %a, %0 : f64
1154+
linalg.yield %2 : f64
1155+
} -> tensor<?x?xf64>
1156+
return %1 : tensor<?x?xf64>
1157+
}
1158+
11091159
#trait_sampled_dense_dense = {
11101160
indexing_maps = [
11111161
affine_map<(i,j,k) -> (i,j)>, // S

0 commit comments

Comments
 (0)