@@ -54,13 +54,16 @@ namespace {
54
54
enum class Kind { kTensor , kInvariant , kMulF , kMulI , kAddF , kAddI };
55
55
56
56
// / Tensor expression. Represents a MLIR expression in tensor index notation.
57
- // / For tensors and invariants, e0 denotes the tensor index. For all binary
58
- // / operations, e0 and e1 denote the index of the children tensor expressions.
57
+ // / For tensors, e0 denotes the tensor index. For invariants, the IR value is
58
+ // / stored directly. For binary operations, e0 and e1 denote the index of the
59
+ // / children tensor expressions.
59
60
struct TensorExp {
60
- TensorExp (Kind k, unsigned x, unsigned y) : kind(k), e0 (x), e1 (y) {}
61
+ TensorExp (Kind k, unsigned x, unsigned y, Value v)
62
+ : kind(k), e0 (x), e1 (y), val(v) {}
61
63
Kind kind;
62
64
unsigned e0 ;
63
65
unsigned e1 ;
66
+ Value val;
64
67
};
65
68
66
69
// / Lattice point. Each lattice point consist of a conjunction of tensor
@@ -85,11 +88,12 @@ class Merger {
85
88
: numTensors(t), numLoops(l), isSparse(t, std::vector<bool >(l, false )) {}
86
89
87
90
// / Adds a tensor expression. Returns its index.
88
- unsigned addExp (Kind k, unsigned e0 , unsigned e1 = -1u ) {
91
+ unsigned addExp (Kind k, unsigned e0 , unsigned e1 = -1u , Value v = Value() ) {
89
92
unsigned e = tensorExps.size ();
90
- tensorExps.push_back (TensorExp (k, e0 , e1 ));
93
+ tensorExps.push_back (TensorExp (k, e0 , e1 , v ));
91
94
return e;
92
95
}
96
+ unsigned addExp (Kind k, Value v) { return addExp (k, -1u , -1u , v); }
93
97
94
98
// / Adds an iteration lattice point. Returns its index.
95
99
unsigned addLat (unsigned t, unsigned i, unsigned e) {
@@ -339,7 +343,6 @@ static bool computeIterationGraph(linalg::GenericOp op,
339
343
// / building (compared to using the SSA representation everywhere).
340
344
static Optional<unsigned > buildTensorExp (Merger &merger, linalg::GenericOp op,
341
345
Value val) {
342
- Operation *def = val.getDefiningOp ();
343
346
if (auto arg = val.dyn_cast <BlockArgument>()) {
344
347
unsigned argN = arg.getArgNumber ();
345
348
if (arg.getOwner ()->getParentOp () == op) {
@@ -348,10 +351,16 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
348
351
auto map = op.getIndexingMap (argN);
349
352
if (map.isProjectedPermutation ())
350
353
return merger.addExp (Kind::kTensor , argN);
351
- } else {
352
- // Any parameter of a higher op is invariant in the tensor expression.
353
- return merger.addExp (Kind::kInvariant , argN);
354
+ // Cannot handle (yet).
355
+ return None;
354
356
}
357
+ // Any parameter of a higher op is invariant.
358
+ return merger.addExp (Kind::kInvariant , val);
359
+ }
360
+ Operation *def = val.getDefiningOp ();
361
+ if (def->getBlock () != &op.region ().front ()) {
362
+ // Something defined outside is invariant.
363
+ return merger.addExp (Kind::kInvariant , val);
355
364
} else if (def->getNumOperands () == 2 ) {
356
365
// Construct binary operations if subexpressions could be built.
357
366
auto x = buildTensorExp (merger, op, def->getOperand (0 ));
@@ -380,9 +389,12 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
380
389
Kind kind = merger.exp (exp).kind ;
381
390
if (kind == Kind::kTensor || kind == Kind::kInvariant ) {
382
391
// Either the index is really used in the tensor expression, or it it
383
- // set to the "non-existing dense index" in that dimension.
392
+ // set to the "non-existing dense index" in that dimension. Invariant
393
+ // expressions borrow the output tensor indices.
384
394
unsigned s = merger.addSet ();
385
- merger.set (s).push_back (merger.addLat (merger.exp (exp).e0 , idx, exp));
395
+ unsigned t = kind == Kind::kTensor ? merger.exp (exp).e0
396
+ : op.getNumInputsAndOutputs () - 1 ;
397
+ merger.set (s).push_back (merger.addLat (t, idx, exp));
386
398
return s;
387
399
}
388
400
unsigned s0 = buildLattices (merger, op, merger.exp (exp).e0 , idx);
@@ -502,7 +514,7 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
502
514
if (merger.exp (exp).kind == Kind::kTensor )
503
515
return genTensorLoad (merger, codegen, rewriter, op, merger.exp (exp).e0 );
504
516
else if (merger.exp (exp).kind == Kind::kInvariant )
505
- return op. getParentRegion ()-> front (). getArgument ( merger.exp (exp).e0 ) ;
517
+ return merger.exp (exp).val ;
506
518
Value v0 = genExp (merger, codegen, rewriter, op, merger.exp (exp).e0 );
507
519
Value v1 = genExp (merger, codegen, rewriter, op, merger.exp (exp).e1 );
508
520
switch (merger.exp (exp).kind ) {
0 commit comments