@@ -324,7 +324,20 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// External model implementation of PartialReductionInterface for LinalgOps.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                           ArrayRef<int> reductionDims,
+                                           unsigned resultNumber) {
+  AffineMap map =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+  for (int redPos : reductionDims) {
+    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+/// External model implementation of PartialReductionInterface for
+/// LinalgOps.
 template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
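
For intuition, here is a small standalone sketch (not part of the patch) of what the new helper computes. It models indexing maps as plain vectors of dim positions instead of AffineMap, and the matmul-style maps in the comments are assumptions for illustration only.

// Illustrative analogue of getPartialResultAffineMap: append each tiled
// reduction dimension to the results of the init operand's indexing map.
// For a matmul-like op with iterators (d0, d1, d2) and init map
// (d0, d1, d2) -> (d0, d1), reductionDims = {2} yields the partial-result
// map (d0, d1, d2) -> (d0, d1, d2): the reduction dim becomes a trailing
// parallel dim of the partial accumulator.
#include <cassert>
#include <vector>

std::vector<int> appendReductionDims(std::vector<int> initMapDims,
                                     const std::vector<int> &reductionDims) {
  for (int redPos : reductionDims)
    initMapDims.push_back(redPos); // mirrors map.insertResult(...)
  return initMapDims;
}

int main() {
  std::vector<int> partial = appendReductionDims({0, 1}, {2});
  assert((partial == std::vector<int>{0, 1, 2}));
}
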
@@ -338,11 +351,24 @@ struct LinalgOpPartialReductionInterface
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    // LinalgOp implements TilingInterface.
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    SmallVector<OpFoldResult> shape =
+        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                            [](Range x) { return x.size; });
+
+    SmallVector<OpFoldResult> tiledShape;
+    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+      if (isZeroIndex(tileSize)) {
+        tiledShape.push_back(dimSize);
+      } else {
+        tiledShape.push_back(tileSize);
+      }
+    }
+
     SmallVector<Value> inits;
     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
          ++initIdx) {
-      // Insert the new parallel dimension based on the index of the reduction
-      // loops. This could be controlled by user for more flexibility.
      SmallVector<Operation *, 4> combinerOps;
      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                          combinerOps) ||
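
A quick standalone sketch (assumed values, not from the patch) of the tile-size selection above: a zero tile size means the dimension is left untiled, so the full iteration-domain extent is kept; plain int64_t extents stand in for OpFoldResult.

// Analogue of the zip_equal loop: take the tile size where one was given,
// otherwise fall back to the full dimension extent.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> computeTiledShape(const std::vector<int64_t> &sizes,
                                       const std::vector<int64_t> &shape) {
  std::vector<int64_t> tiledShape;
  for (size_t i = 0; i < shape.size(); ++i)
    tiledShape.push_back(sizes[i] == 0 ? shape[i] : sizes[i]);
  return tiledShape;
}

int main() {
  // Iteration domain 128x256x512, tiled only along the last dim by 64.
  assert((computeTiledShape({0, 0, 64}, {128, 256, 512}) ==
          std::vector<int64_t>{128, 256, 64}));
}
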
@@ -355,33 +381,19 @@ struct LinalgOpPartialReductionInterface
         return op->emitOpError(
             "Failed to get an identity value for the reduction operation.");
 
-      ArrayRef<int64_t> oldShape =
-          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-      // Calculate the new shape, we insert the new dimensions based on the
-      // index of the reduction dimensions.
-      SmallVector<int64_t> newOutputShape;
-      SmallVector<Value> dynamicDims;
-      int64_t currReductionDims = 0;
-      DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                     reductionDims.end());
-      for (int64_t idx :
-           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-        if (reductionDimsSet.contains(idx)) {
-          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-          currReductionDims++;
-          continue;
-        }
-        int64_t oldIdx = idx - currReductionDims;
-        int64_t dim = oldShape[oldIdx];
-        newOutputShape.push_back(dim);
-        if (ShapedType::isDynamic(dim))
-          dynamicDims.push_back(b.create<tensor::DimOp>(
-              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      // Append the new partial result dimensions.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+      SmallVector<OpFoldResult> partialResultShape;
+      for (AffineExpr dimExpr : partialMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
-      Value emptyTensor = b.create<tensor::EmptyOp>(
-          loc, newOutputShape,
-          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+      Type elType =
+          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Value emptyTensor =
+          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
      auto identityTensor =
          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
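
Continuing the standalone analogue (assumed shapes, not from the patch): the partial-result shape is simply the tiled iteration-domain extents gathered in the order of the partial map's dim positions, so a transposed init map makes the gather a real permutation.

// Analogue of the partialResultShape loop: gather tiled extents by the dim
// positions of the partial-result map.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> gatherPartialShape(const std::vector<int> &partialMapDims,
                                        const std::vector<int64_t> &tiledShape) {
  std::vector<int64_t> shape;
  for (int dim : partialMapDims)
    shape.push_back(tiledShape[dim]);
  return shape;
}

int main() {
  // Assumed example: init map (d0, d1, d2) -> (d1, d0) with reduction dim d2,
  // so the partial map results are {1, 0, 2}; tiled domain {128, 256, 64}.
  assert((gatherPartialShape({1, 0, 2}, {128, 256, 64}) ==
          std::vector<int64_t>{256, 128, 64}));
}
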
@@ -407,11 +419,7 @@ struct LinalgOpPartialReductionInterface
       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
       // this with a for range loop when we have it.
       AffineMap newMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      for (int redPos : reductionDims) {
-        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                     newMap.getNumResults());
-      }
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       newInitMaps.push_back(newMap);
     }
 
@@ -476,29 +484,74 @@ struct LinalgOpPartialReductionInterface
                               Location loc, ValueRange partialReduce,
                               ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    SmallVector<int64_t> reductionDimsInt64(reductionDims);
-    auto reduction = b.create<linalg::ReduceOp>(
-        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          int64_t numInits = linalgOp.getNumDpsInits();
-          SmallVector<Value> yieldedValues;
-          for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dims as permuted by the partial result map.
+
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<Operation *> mergeOperations;
+    SmallVector<Value> replacements;
+    for (int idx : llvm::seq(numInits)) {
+      // linalg.reduce's iteration space is the result's iteration space (and
+      // not the operation's iteration space). To account for this, permute the
+      // reduction dimensions based on the partial result map.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
+      SmallVector<int64_t> partialReductionDims;
+      for (auto [resultNum, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+        if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+          partialReductionDims.push_back(resultNum);
+        }
+      }
+
+      Value partialResult = partialReduce[idx];
+      Value init = linalgOp.getDpsInits()[idx];
+
+      auto reduction = b.create<linalg::ReduceOp>(
+          loc, partialResult, init, partialReductionDims,
+          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
-            clonedReductionOp->setOperand(0, inputs[idx]);
-            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-            // Yield.
-            yieldedValues.push_back(clonedReductionOp->getResult(0));
-          }
-          b.create<linalg::YieldOp>(loc, yieldedValues);
-        });
-    return MergeResult{
-        {reduction.getOperation()},
-        llvm::map_to_vector(reduction->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+            clonedReductionOp->setOperand(0, inputs[0]);
+            clonedReductionOp->setOperand(1, inputs[1]);
+            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+          });
+
+      mergeOperations.push_back(reduction);
+      replacements.push_back(reduction->getResult(0));
+    }
+
+    return MergeResult{mergeOperations, replacements};
+  }
+
+  LogicalResult getPartialResultTilePosition(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &resultOffsets,
+      SmallVector<OpFoldResult> &resultSizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap partialMap =
+        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+    for (AffineExpr dimExpr : partialMap.getResults()) {
+      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      resultSizes.push_back(sizes[dim]);
+
+      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and are always output in the same
+        // place. So use offset 0 for them.
+        resultOffsets.push_back(b.getIndexAttr(0));
+      } else {
+        resultOffsets.push_back(offsets[dim]);
+      }
+    }
+
+    return success();
   }
 };
 
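
Finally, a standalone sketch of the new offset/size logic in getPartialResultTilePosition (all names and numbers below are illustrative assumptions, not from the patch): positions of the partial result that correspond to reduction dims always get offset 0, while every other position takes the tile offset of its iteration-space dim.

// Analogue of getPartialResultTilePosition: walk the partial-result map's dim
// positions; reduction dims are written at offset 0, parallel dims at the
// current tile offset.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

void partialTilePosition(const std::vector<int> &partialMapDims,
                         const std::vector<int> &reductionDims,
                         const std::vector<int64_t> &offsets,
                         const std::vector<int64_t> &sizes,
                         std::vector<int64_t> &resultOffsets,
                         std::vector<int64_t> &resultSizes) {
  for (int dim : partialMapDims) {
    resultSizes.push_back(sizes[dim]);
    bool isReduction = std::find(reductionDims.begin(), reductionDims.end(),
                                 dim) != reductionDims.end();
    resultOffsets.push_back(isReduction ? 0 : offsets[dim]);
  }
}

int main() {
  std::vector<int64_t> offs, szs;
  // Assumed tile: partial map (d0, d1, d2) -> (d0, d1, d2), reduction dim d2,
  // tile offsets {32, 0, 192}, tile sizes {32, 256, 64}.
  partialTilePosition({0, 1, 2}, {2}, {32, 0, 192}, {32, 256, 64}, offs, szs);
  assert((offs == std::vector<int64_t>{32, 0, 0}));
  assert((szs == std::vector<int64_t>{32, 256, 64}));
}
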