@@ -69,38 +69,36 @@ struct ShapeDimension {
 static ShapeDimension
 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                           bool fromSubViewOpOnly = false) {
-  auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
     // The method `getRangeFromOperandShape` requires using SubViewOp or
     // SubTensorOps. If the value isn't defined from there, continue.
     // TODO: The method should be adapted to get the values from
     // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
     // currently returns a `linalg.range`. The fix here is to move this op to
     // `std` dialect and add the method to `ViewInterface`.
     if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
-                                 en.value().getDefiningOp()))
+                                 opOperand->get().getDefiningOp()))
       continue;
 
-    unsigned idx = en.index();
-    auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(llvm::dbgs()
-               << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    AffineMap map = op.getTiedIndexingMap(opOperand);
+    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
+                            << opOperand->getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
                << "getShapeDefiningLoopRange map: " << map << "\n");
-    Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
-    for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+    for (auto en : llvm::enumerate(map.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
       if (!dimExpr)
         continue;
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+      if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
-        LLVM_DEBUG(llvm::dbgs()
-                   << "getShapeDefiningLoopRange shape: " << shape << "\n");
-        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
+        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
+                                << opOperand->get() << "\n");
+        return ShapeDimension{opOperand->get(),
+                              static_cast<unsigned>(en.index())};
      }
    }
  }
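
The hunk above shows the core idiom of this commit: operand/map pairs are no longer recovered by zipping `llvm::enumerate(op.getShapedOperands())` against the raw `indexing_maps` attribute; each `OpOperand *` from `getInputAndOutputOperands()` carries its own position and value, and `getTiedIndexingMap` fetches the matching map. A minimal sketch of the new idiom, assuming `op` is any LinalgOp (the loop body is illustrative only):

    // Walk all input and output operands of a LinalgOp and query the
    // per-operand metadata through the OpOperand* handle.
    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
      AffineMap map = op.getTiedIndexingMap(opOperand); // replaces maps[idx]
      Value value = opOperand->get();                   // replaces en.value()
      unsigned idx = opOperand->getOperandNumber();     // replaces en.index()
      (void)map; (void)value; (void)idx;
    }
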
@@ -122,26 +120,24 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
 // would need to add the intermediate results to `linalg.yield`. After that a
 // canonicalization pass would move the unused output args of the `tiled_loop`
 // to the `input` section.
-static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
   auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
   if (!tiledLoop)
-    return llvm::to_vector<4>(producer.getShapedOperands());
+    return producer.getInputAndOutputOperands();
 
-  SmallVector<Value, 4> tiledOperands;
+  SmallVector<Value> tiledOperands;
   assert(producer.hasTensorSemantics() &&
          "only fusion on tensors is currently supported for TiledLinalgOp");
 
-  for (auto producerInput : producer.getInputTensors()) {
-    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
     if (addedInput == nullptr)
-      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
     BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
     tiledOperands.push_back(addedBlockArg);
   }
-  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
-    Value producerOutput = en.value();
-
-    Value result = producer->getResult(en.index());
+  for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+    OpResult result = producer.getTiedOpResult(producerOutput);
     OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
     OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
     assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@ static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
     int opNumber = isInput ? resultInputOperand->getOperandNumber()
                            : resultOutputOperand->getOperandNumber();
 
-    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
     if (addedOutput == nullptr)
-      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
-                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+      addedOutput =
+          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
+                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());
 
     OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
     auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
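
`getTiledOperands` benefits from the same handles: `getTiedOpResult` maps an output `OpOperand *` directly to the `OpResult` it produces, replacing the manual `enumerate` index bookkeeping. A sketch of that tie, assuming `producer` is a LinalgOp with tensor semantics and the usual inputs-before-outputs operand order:

    // For each output tensor operand, recover the tied op result; under the
    // stated operand-order assumption, the result number mirrors the
    // operand's position past the inputs.
    for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
      OpResult result = producer.getTiedOpResult(producerOutput);
      assert(result.getResultNumber() ==
             producerOutput->getOperandNumber() - producer.getNumInputs());
    }
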
@@ -200,7 +197,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   }
 
   SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(producer.getNumShapedOperands());
+  clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
   llvm_unreachable("SubviewOp or SubTensorOp expected");
 }
 
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
+/// Fuses the producer into the loop immediately enclosing the consumer.
+/// This is achieved by "recomputing" the producer at the time it
+/// is needed just before the consumer.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                      OpOperand &consumerOpOperand) {
   LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumerOp);
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerOpResult.getResultNumber());
   LinalgOp fusedProducer =
-      fuse(b, producerOp,
-           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
            consumerOpOperand);
 
   // Replace use.
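
On the consumer side, `fuseProducerOfTensor` previously selected the producer's map with the result-number-based `getOutputIndexingMap`; it now goes through the output `OpOperand` explicitly. A hedged sketch of that two-step lookup as a standalone helper (the name `getResultIndexingMap` is hypothetical):

    // From a LinalgOp result back to the indexing map of the output
    // operand tied to it.
    static AffineMap getResultIndexingMap(LinalgOp producerOp,
                                          OpResult producerOpResult) {
      OpOperand *opOperand =
          producerOp.getOutputOperand(producerOpResult.getResultNumber());
      return producerOp.getTiedIndexingMap(opOperand);
    }
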
@@ -770,9 +761,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
   FusableOpDependencesTy fusableDependences;
   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
+          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       // Canonicalize indexed generic ops before fusion.
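
Note the call-site adjustment this switch forces throughout the patch: `getInputAndOutputOperands()` yields `OpOperand *`, so helpers such as `findFusableProducer` that keep an `OpOperand &` parameter now receive a dereference. A minimal sketch of the pattern (loop body elided):

    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
      // findFusableProducer keeps its OpOperand& signature, hence the deref.
      auto dependence = findFusableProducer(*opOperand, dependenceGraph);
      if (!dependence)
        continue;
      // ...
    }
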
@@ -905,10 +896,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
     if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
-      for (auto &operand : forOp.getIterOpOperands())
-        if (auto opResult = operand.get().dyn_cast<OpResult>())
-          if (opResult.getOwner() == origOp)
-            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+      for (auto &operand : forOp.getIterOpOperands()) {
+        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
+          if (opResult.getOwner() == origOp) {
+            Value output =
+                origOp.getOutputOperand(opResult.getResultNumber())->get();
+            assert(output.getType().isa<RankedTensorType>());
+            operand.set(output);
+          }
+        }
+      }
     }
   }
   return fusedOps;
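
Finally, the iter-operand forwarding in `fuseOperations` is expanded so the forwarded value can be type-checked: any loop result that still consumes the original producer is rewired to the tensor that initialized the producer's output operand, letting the unfused op die unused. The same logic, flattened with early continues, assuming `forOp` is the outermost tiled `scf.for` and `origOp` the original producer:

    // Rewire any loop iter operand that still consumes a result of the
    // original (unfused) op to that op's output tensor instead.
    for (OpOperand &operand : forOp.getIterOpOperands()) {
      auto opResult = operand.get().dyn_cast<OpResult>();
      if (!opResult || opResult.getOwner() != origOp)
        continue;
      Value output = origOp.getOutputOperand(opResult.getResultNumber())->get();
      assert(output.getType().isa<RankedTensorType>()); // tensor semantics only
      operand.set(output);
    }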