@@ -61,36 +61,34 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
61
61
// / Return the indexing map to use in the fused operation for the operand
62
62
// / `producerOpOperand` of the producer, given the indexing map of the
63
63
// / result of the producer in the consumer.
64
- static void getIndexingMapOfProducerOperandsInFusedOp (
65
- LinalgOp producer , AffineMap fusedConsumerArgIndexMap ,
66
- SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs ) {
64
+ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp (
65
+ OpOperand &producerOpOperand , AffineMap producerResultIndexMap ,
66
+ AffineMap fusedConsumerArgIndexMap ) {
67
67
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
68
68
// from consumer loop -> consumer arg tensor index/producer result tensor
69
69
// index. The fused loop is same as the consumer loop. For each producer arg
70
70
// the indexing map to be computed is a map from consumer loop -> producer
71
71
// arg tensor index.
72
-
73
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap (0 );
74
72
// producerResultIndexMap is a map from producer loop -> tensor index.
75
73
// Compute the inverse to get map from tensor index -> producer loop.
76
74
// The inverse is a map from producer result tensor index -> producer loop.
77
75
AffineMap invProducerResultIndexMap =
78
76
inversePermutation (producerResultIndexMap);
79
77
assert (invProducerResultIndexMap &&
80
78
" expected producer result indexig map to be invertible" );
81
- for ( unsigned argNum : llvm::seq< unsigned >( 0 , producer. getNumInputs ())) {
82
- // argMap is a map from producer loop -> producer arg tensor index.
83
- AffineMap argMap = producer. getInputIndexingMap (argNum);
84
-
85
- // Compose argMap with invProducerResultIndexMap to get a map from
86
- // producer result tensor index -> producer arg tensor index.
87
- AffineMap t1 = argMap. compose ( invProducerResultIndexMap);
88
-
89
- // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
90
- // consumer loop/ fused loop -> producer arg tensor index.
91
- AffineMap indexingMap = t1. compose ( fusedConsumerArgIndexMap);
92
- fusedOpIndexingMapAttrs. push_back ( AffineMapAttr::get (indexingMap));
93
- }
79
+
80
+ LinalgOp producer = cast<LinalgOp>(producerOpOperand. getOwner ());
81
+ // argMap is a map from producer loop -> producer arg tensor index.
82
+ AffineMap argMap =
83
+ producer. getIndexingMap (producerOpOperand. getOperandNumber ());
84
+
85
+ // Compose argMap with invProducerResultIndexMap to get a map from
86
+ // producer result tensor index -> producer arg tensor index.
87
+ AffineMap t1 = argMap. compose (invProducerResultIndexMap);
88
+
89
+ // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
90
+ // consumer loop/ fused loop -> producer arg tensor index.
91
+ return t1. compose (fusedConsumerArgIndexMap);
94
92
}
95
93
96
94
// / Generate the region of the fused tensor operation. The region of the fused
@@ -163,6 +161,18 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
163
161
.drop_front (numProducerIndices)
164
162
.take_front (producer.getNumInputs ()))
165
163
mapper.map (bbArg, fusedBlock->addArgument (bbArg.getType ()));
164
+
165
+ // 4.b. Producer output operand/map that is fused needs to be mapped to the
166
+ // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
167
+ assert (producer->getNumResults () == 1 && " expected single result producer" );
168
+ if (producer.isInitTensor (&producer.getOutputOpOperands ()[0 ])) {
169
+ BlockArgument bbArg =
170
+ producerBlock.getArguments ()
171
+ .drop_front (numConsumerIndices + producer.getNumInputs ())
172
+ // TODO: bbArg index of
173
+ .front ();
174
+ mapper.map (bbArg, fusedBlock->addArgument (bbArg.getType ()));
175
+ }
166
176
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
167
177
for (BlockArgument bbArg : consumerBlock.getArguments ()
168
178
.drop_front (numConsumerIndices)
@@ -221,73 +231,90 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
221
231
!controlFn (producer->getResult (0 ), consumerOpOperand))
222
232
return llvm::None;
223
233
224
- unsigned numFusedOperands =
225
- producer.getNumInputs () + consumer.getNumInputs () - 1 ;
226
-
227
- // Compute the fused operands list,
228
- SmallVector<Value, 2 > fusedOperands;
229
- fusedOperands.reserve (numFusedOperands);
230
- auto consumerOperands = consumer.getInputs ();
231
- auto producerOperands = producer.getInputs ();
232
- fusedOperands.assign (consumerOperands.begin (),
233
- std::next (consumerOperands.begin (), consumerIdx));
234
- fusedOperands.append (producerOperands.begin (), producerOperands.end ());
235
- fusedOperands.append (std::next (consumerOperands.begin (), consumerIdx + 1 ),
236
- consumerOperands.end ());
237
-
238
- // Compute indexing_maps for the fused operation. The indexing_maps for the
239
- // operands of the consumers that aren't fused are the same. The
240
- // indexing_maps for the producers need to be computed based on the
241
- // indexing_map of the operand at consumerIdx in the consumer.
242
- SmallVector<Attribute, 4 > fusedIndexMaps;
243
- auto consumerIndexMaps = consumer.indexing_maps ();
244
- fusedIndexMaps.reserve (fusedOperands.size () + consumer.getNumOutputs ());
245
- fusedIndexMaps.assign (consumerIndexMaps.begin (),
246
- std::next (consumerIndexMaps.begin (), consumerIdx));
247
- // Compute indexing maps for the producer args in the fused operation.
248
- getIndexingMapOfProducerOperandsInFusedOp (
249
- producer, consumer.getInputIndexingMap (consumerIdx), fusedIndexMaps);
250
-
251
- // Append the indexing maps for the remaining consumer operands.
252
- fusedIndexMaps.append (std::next (consumerIndexMaps.begin (), consumerIdx + 1 ),
253
- consumerIndexMaps.end ());
234
+ // TODO: allow fusing the producer of an output operand.
235
+ assert (consumerIdx < consumer.getNumInputs () &&
236
+ " expected producer of input operand" );
237
+
238
+ // Compute the fused operands list and indexing maps.
239
+ SmallVector<Value> fusedOperands;
240
+ SmallVector<AffineMap> fusedIndexMaps;
241
+ fusedOperands.reserve (producer->getNumOperands () +
242
+ consumer->getNumOperands ());
243
+ fusedIndexMaps.reserve (producer->getNumOperands () +
244
+ consumer->getNumOperands ());
245
+ // In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
246
+ // 3. Consumer input operands/maps up to consumerIdx (exclusive).
247
+ llvm::append_range (fusedOperands,
248
+ consumer.getInputs ().take_front (consumerIdx));
249
+ llvm::append_range (
250
+ fusedIndexMaps,
251
+ ArrayRef<AffineMap>{consumer.getInputIndexingMaps ()}.take_front (
252
+ consumerIdx));
253
+ // 4. Splice in producer's input operands/maps.
254
+ llvm::append_range (fusedOperands, producer.getInputs ());
255
+ assert (producer->getNumResults () == 1 && " expected single result producer" );
256
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap (0 );
257
+ for (auto &inputOpOperand : producer.getInputOpOperands ()) {
258
+ // Compute indexing maps for the producer args in the fused operation.
259
+ AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp (
260
+ inputOpOperand, producerResultIndexMap,
261
+ consumer.getInputIndexingMap (consumerIdx));
262
+ fusedIndexMaps.push_back (map);
263
+ }
264
+ // 4.b. Producer output operand/map that is fused needs to be passed if it is
265
+ // an "initTensor" (i.e. its value is actually read).
266
+ assert (producer->getNumResults () == 1 && " expected single result producer" );
267
+ if (producer.isInitTensor (&producer.getOutputOpOperands ()[0 ])) {
268
+ llvm::append_range (fusedOperands, producer.getOutputs ().take_front ());
269
+ // Compute indexing maps for the producer args in the fused operation.
270
+ AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp (
271
+ producer.getOutputOpOperands ().front (), producerResultIndexMap,
272
+ consumer.getOutputIndexingMap (0 ));
273
+ fusedIndexMaps.push_back (map);
274
+ }
275
+ // 5. Remaining consumer's input operands/maps (drop past index
276
+ // `consumerIdx`).
277
+ llvm::append_range (fusedOperands,
278
+ consumer.getInputs ().drop_front (consumerIdx + 1 ));
279
+ llvm::append_range (
280
+ fusedIndexMaps,
281
+ ArrayRef<AffineMap>{consumer.getInputIndexingMaps ()}.drop_front (
282
+ consumerIdx + 1 ));
283
+ // 6. All of consumer's output operands (skip operands: added by the builder).
284
+ // llvm::append_range(fusedOperands, consumer.getOutputs());
285
+ llvm::append_range (fusedIndexMaps, consumer.getOutputIndexingMaps ());
286
+ // 7. All of producer's output operands/maps except the one fused.
287
+ // TODO: allow fusion of multi-result producers.
288
+ assert (producer->getNumResults () == 1 && " expected single result producer" );
254
289
255
290
// Generate the fused op.
256
- LinalgOp fusedOp;
291
+ Operation * fusedOp;
257
292
if (isa<GenericOp>(producer.getOperation ()) &&
258
293
isa<GenericOp>(consumer.getOperation ())) {
259
- fusedOp =
260
- rewriter
261
- .create <GenericOp>(consumer.getLoc (), consumer->getResultTypes (),
262
- /* inputs=*/ fusedOperands,
263
- // TODO: handle outputs.
264
- consumer.getOutputs (),
265
- rewriter.getArrayAttr (fusedIndexMaps),
266
- consumer.iterator_types (),
267
- /* doc=*/ nullptr ,
268
- /* library_call=*/ nullptr ,
269
- /* sparse=*/ nullptr )
270
- .getOperation ();
294
+ fusedOp = rewriter.create <GenericOp>(
295
+ consumer.getLoc (), consumer->getResultTypes (),
296
+ /* inputs=*/ fusedOperands,
297
+ // TODO: handle outputs.
298
+ consumer.getOutputs (), rewriter.getAffineMapArrayAttr (fusedIndexMaps),
299
+ consumer.iterator_types (),
300
+ /* doc=*/ nullptr ,
301
+ /* library_call=*/ nullptr ,
302
+ /* sparse=*/ nullptr );
271
303
} else {
272
- fusedOp =
273
- rewriter
274
- .create <IndexedGenericOp>(
275
- consumer.getLoc (), consumer->getResultTypes (),
276
- /* inputs=*/ fusedOperands,
277
- // TODO: handle outputs.
278
- consumer.getOutputs (), rewriter.getArrayAttr (fusedIndexMaps),
279
- consumer.iterator_types (),
280
- /* doc=*/ nullptr ,
281
- /* library_call=*/ nullptr ,
282
- /* sparse=*/ nullptr )
283
- .getOperation ();
304
+ fusedOp = rewriter.create <IndexedGenericOp>(
305
+ consumer.getLoc (), consumer->getResultTypes (),
306
+ /* inputs=*/ fusedOperands,
307
+ // TODO: handle outputs.
308
+ consumer.getOutputs (), rewriter.getAffineMapArrayAttr (fusedIndexMaps),
309
+ consumer.iterator_types (),
310
+ /* doc=*/ nullptr ,
311
+ /* library_call=*/ nullptr ,
312
+ /* sparse=*/ nullptr );
284
313
}
285
314
286
315
// Construct an AffineMap from consumer loops to producer loops.
287
316
// consumer loop -> tensor index
288
317
AffineMap consumerResultIndexMap = consumer.getInputIndexingMap (consumerIdx);
289
- // producer loop -> tensor index
290
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap (0 );
291
318
// tensor index -> producer loop
292
319
AffineMap invProducerResultIndexMap =
293
320
inversePermutation (producerResultIndexMap);
@@ -297,9 +324,9 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
297
324
AffineMap consumerToProducerLoopsMap =
298
325
invProducerResultIndexMap.compose (consumerResultIndexMap);
299
326
300
- generateFusedElementwiseOpRegion (rewriter, fusedOp. getOperation () , producer,
301
- consumer, consumerToProducerLoopsMap ,
302
- consumerIdx, consumer.getNumLoops ());
327
+ generateFusedElementwiseOpRegion (rewriter, fusedOp, producer, consumer ,
328
+ consumerToProducerLoopsMap, consumerIdx ,
329
+ consumer.getNumLoops ());
303
330
return SmallVector<Value, 1 >(fusedOp->getResults ());
304
331
}
305
332
0 commit comments