@@ -328,6 +328,17 @@ struct LinalgOpTilingInterface
328
328
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
329
329
// ===----------------------------------------------------------------------===//
330
330
331
+ // / In a given set vector, get the position of a particular element.
332
+ std::optional<int > getPositionIn (const llvm::SetVector<unsigned > &reductionDims,
333
+ unsigned value) {
334
+ for (auto [index, reductionDim] : llvm::enumerate (reductionDims)) {
335
+ if (reductionDim == value) {
336
+ return index;
337
+ }
338
+ }
339
+ return std::nullopt;
340
+ }
341
+
331
342
// / Return an AffineMaps to use for the `outs` operands of the linalg op
332
343
// / generated for partial results. The new AffineMap is the AffineMap of the
333
344
// / untiled op with reduction dimensions appended at end in order in which they
@@ -348,28 +359,79 @@ getPartialResultAffineMaps(LinalgOp linalgOp,
348
359
return partialReductionMaps;
349
360
}
350
361
351
- // / Return the slice of the `initValue` to use as input to the partial reduction
352
- // / op generated.
353
- static Operation *getInitSliceForOuterReduction (
354
- OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
362
+ struct InitSliceInfo {
363
+ SmallVector<int64_t > resultShape;
364
+ SmallVector<OpFoldResult> offsets;
365
+ SmallVector<OpFoldResult> sizes;
366
+ SmallVector<OpFoldResult> strides;
367
+ };
368
+
369
+ // / Return the result type, offsets, sizes and strides of the slice of the
370
+ // / `initValue` to use as input to the partial reduction op generated with
371
+ // / outer reduction strategy.
372
+ static InitSliceInfo getInitSliceInfoForOuterReduction (
373
+ MLIRContext *context, ArrayRef<OpFoldResult> offsets,
355
374
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned > &reductionDims,
356
375
AffineMap partialReductionMap) {
357
376
int64_t initRank = partialReductionMap.getNumResults ();
358
377
SmallVector<OpFoldResult> initOffsets, initSizes;
359
- SmallVector<OpFoldResult> initStrides (initRank, b.getIndexAttr (1 ));
378
+ Attribute zero = IntegerAttr::get (IndexType::get (context), 0 );
379
+ Attribute one = IntegerAttr::get (IndexType::get (context), 1 );
380
+ SmallVector<OpFoldResult> initStrides (initRank, one);
360
381
for (AffineExpr dimExpr : partialReductionMap.getResults ()) {
361
382
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
362
383
if (reductionDims.contains (dim)) {
363
- initOffsets.push_back (b. getIndexAttr ( 0 ) );
384
+ initOffsets.push_back (zero );
364
385
} else {
365
386
initOffsets.push_back (offsets[dim]);
366
387
}
367
388
initSizes.push_back (sizes[dim]);
368
389
}
369
- // TODO: Use SubsetExtractOpInterface here once available.
370
- auto extractSlice = b.create <tensor::ExtractSliceOp>(
371
- loc, initValue, initOffsets, initSizes, initStrides);
372
- return extractSlice;
390
+ SmallVector<int64_t > resultShape;
391
+ std::tie (resultShape, std::ignore) = decomposeMixedValues (initSizes);
392
+ return {resultShape, initOffsets, initSizes, initStrides};
393
+ }
394
+
395
+ // / Return the result type, offsets, sizes and strides of the slice of the
396
+ // / `initValue` to use as input to the partial reduction op generated with
397
+ // / outer parallel strategy.
398
+ static InitSliceInfo getInitSliceInfoForOuterParallel (
399
+ MLIRContext *context, ValueRange ivs, ArrayRef<OpFoldResult> offsets,
400
+ ArrayRef<OpFoldResult> sizes, const SetVector<unsigned > &reductionDims,
401
+ AffineMap partialReductionMap) {
402
+ int64_t initRank = partialReductionMap.getNumResults ();
403
+ SmallVector<OpFoldResult> initOffsets, initSizes;
404
+ Attribute one = IntegerAttr::get (IndexType::get (context), 1 );
405
+ SmallVector<OpFoldResult> initStrides (initRank, one);
406
+ SmallVector<OpFoldResult> resultShape;
407
+ for (AffineExpr dimExpr : partialReductionMap.getResults ()) {
408
+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
409
+ if (std::optional<int > dimPos = getPositionIn (reductionDims, dim)) {
410
+ initOffsets.push_back (ivs[dimPos.value ()]);
411
+ initSizes.push_back (one);
412
+ } else {
413
+ initOffsets.push_back (offsets[dim]);
414
+ initSizes.push_back (sizes[dim]);
415
+ resultShape.push_back (sizes[dim]);
416
+ }
417
+ }
418
+ SmallVector<int64_t > staticShapes;
419
+ std::tie (staticShapes, std::ignore) = decomposeMixedValues (resultShape);
420
+ return {staticShapes, initOffsets, initSizes, initStrides};
421
+ }
422
+
423
+ static InitSliceInfo getInitSliceInfo (
424
+ MLIRContext *context, ReductionTilingStrategy strategy, ValueRange ivs,
425
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
426
+ const SetVector<unsigned > &reductionDims, AffineMap partialReductionMap) {
427
+ if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
428
+ return getInitSliceInfoForOuterReduction (
429
+ context, offsets, sizes, reductionDims, partialReductionMap);
430
+ }
431
+ assert (strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
432
+ " unexpected ReductionTilingStrategy" );
433
+ return getInitSliceInfoForOuterParallel (context, ivs, offsets, sizes,
434
+ reductionDims, partialReductionMap);
373
435
}
374
436
375
437
// / External model implementation of PartialReductionInterface for
@@ -439,18 +501,11 @@ struct LinalgOpPartialReductionInterface
439
501
return inits;
440
502
}
441
503
442
- FailureOr<TilingResult>
443
- tileToPartialReduction (Operation *op, OpBuilder &b, Location loc,
444
- ReductionTilingStrategy tilingStrategy,
445
- ValueRange init, ArrayRef<OpFoldResult> offsets,
446
- ArrayRef<OpFoldResult> sizes,
447
- const SetVector<unsigned > &reductionDims) const {
448
- if (tilingStrategy !=
449
- ReductionTilingStrategy::PartialReductionOuterReduction) {
450
- // TODO: Add support for `PartialReductionOuterParallel` strategy.
451
- return op->emitOpError (" unsupported partial reduction tiling with "
452
- " `PartialReductionOuterParallel` strategy" );
453
- }
504
+ FailureOr<TilingResult> tileToPartialReduction (
505
+ Operation *op, OpBuilder &b, Location loc,
506
+ ReductionTilingStrategy tilingStrategy, ValueRange init, ValueRange ivs,
507
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
508
+ const SetVector<unsigned > &reductionDims) const {
454
509
OpBuilder::InsertionGuard guard (b);
455
510
auto linalgOp = cast<LinalgOp>(op);
456
511
@@ -459,7 +514,16 @@ struct LinalgOpPartialReductionInterface
459
514
460
515
// Step 1. Extend init maps to have reduction dimension dims, since we
461
516
// are converting them to parallel dimensions.
462
- SmallVector<AffineMap> newInitMaps = partialReductionMaps;
517
+ SmallVector<AffineMap> newInitMaps;
518
+ if (tilingStrategy ==
519
+ ReductionTilingStrategy::PartialReductionOuterReduction) {
520
+ newInitMaps = llvm::to_vector (partialReductionMaps);
521
+ } else {
522
+ newInitMaps = llvm::map_to_vector (
523
+ linalgOp.getDpsInitsMutable (), [&](OpOperand &opOperand) {
524
+ return linalgOp.getMatchingIndexingMap (&opOperand);
525
+ });
526
+ }
463
527
464
528
// Step 2a: Extract a slice of the input operands.
465
529
SmallVector<Value> tiledInputs = makeTiledShapes (
@@ -473,10 +537,17 @@ struct LinalgOpPartialReductionInterface
473
537
SmallVector<Value, 1 > tiledInits;
474
538
for (auto [partialReductionMap, valueToTile] :
475
539
llvm::zip_equal (partialReductionMaps, init)) {
476
- Operation *sliceOp =
477
- getInitSliceForOuterReduction (b, loc, valueToTile, offsets, sizes,
478
- reductionDims, partialReductionMap);
479
- tiledInits.push_back (sliceOp->getResult (0 ));
540
+ InitSliceInfo sliceInfo =
541
+ getInitSliceInfo (b.getContext (), tilingStrategy, ivs, offsets, sizes,
542
+ reductionDims, partialReductionMap);
543
+ auto valueToTileType = cast<RankedTensorType>(valueToTile.getType ());
544
+ RankedTensorType sliceResultType = RankedTensorType::get (
545
+ sliceInfo.resultShape , valueToTileType.getElementType (),
546
+ valueToTileType.getEncoding ());
547
+ auto sliceOp = b.create <tensor::ExtractSliceOp>(
548
+ loc, sliceResultType, valueToTile, sliceInfo.offsets , sliceInfo.sizes ,
549
+ sliceInfo.strides );
550
+ tiledInits.push_back (sliceOp.getResult ());
480
551
generatedSlices.push_back (sliceOp);
481
552
}
482
553
@@ -491,19 +562,31 @@ struct LinalgOpPartialReductionInterface
491
562
// Step 3. Change the reduction dim iterator types.
492
563
SmallVector<utils::IteratorType> newIteratorTypes =
493
564
linalgOp.getIteratorTypesArray ();
494
- for (int dim : reductionDims)
495
- newIteratorTypes[dim] = utils::IteratorType::parallel;
565
+ if (tilingStrategy ==
566
+ ReductionTilingStrategy::PartialReductionOuterReduction) {
567
+ for (int dim : reductionDims)
568
+ newIteratorTypes[dim] = utils::IteratorType::parallel;
569
+ }
496
570
497
571
// Step 4. Create the new generic op.
572
+ Operation *partialReductionOp;
498
573
auto resultTypes = ValueRange (tiledInits).getTypes ();
499
- auto genericOp = b.create <GenericOp>(loc, resultTypes, tiledInputs,
500
- tiledInits, newMaps, newIteratorTypes);
501
- IRMapping mapping;
502
- op->getRegion (0 ).cloneInto (&genericOp.getRegion (),
503
- genericOp.getRegion ().begin (), mapping);
574
+ if (tilingStrategy ==
575
+ ReductionTilingStrategy::PartialReductionOuterReduction) {
576
+ auto genericOp = b.create <GenericOp>(
577
+ loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
578
+ IRMapping mapping;
579
+ op->getRegion (0 ).cloneInto (&genericOp.getRegion (),
580
+ genericOp.getRegion ().begin (), mapping);
581
+ partialReductionOp = genericOp.getOperation ();
582
+ } else {
583
+ SmallVector<Value> operands = std::move (tiledInputs);
584
+ llvm::append_range (operands, tiledInits);
585
+ partialReductionOp = mlir::clone (b, op, resultTypes, operands);
586
+ }
504
587
return TilingResult{
505
- {genericOp. getOperation () },
506
- llvm::map_to_vector (genericOp ->getResults (),
588
+ {partialReductionOp },
589
+ llvm::map_to_vector (partialReductionOp ->getResults (),
507
590
[](OpResult r) -> Value { return r; }),
508
591
generatedSlices};
509
592
}
@@ -557,27 +640,19 @@ struct LinalgOpPartialReductionInterface
557
640
}
558
641
559
642
LogicalResult getPartialResultTilePosition (
560
- Operation *op, OpBuilder &b, unsigned resultNumber,
561
- ArrayRef<OpFoldResult> offsets , ArrayRef<OpFoldResult> sizes ,
562
- const SetVector<unsigned > &reductionDims,
643
+ Operation *op, OpBuilder &b, unsigned resultNumber, ValueRange ivs,
644
+ ReductionTilingStrategy tilingStrategy , ArrayRef<OpFoldResult> offsets ,
645
+ ArrayRef<OpFoldResult> sizes, const SetVector<unsigned > &reductionDims,
563
646
SmallVector<OpFoldResult> &resultOffsets,
564
647
SmallVector<OpFoldResult> &resultSizes) const {
565
648
auto linalgOp = cast<LinalgOp>(op);
566
649
SmallVector<AffineMap> partialReductionMaps =
567
650
getPartialResultAffineMaps (linalgOp, reductionDims);
568
-
569
- for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults ()) {
570
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
571
- resultSizes.push_back (sizes[dim]);
572
-
573
- if (llvm::is_contained (reductionDims, dim)) {
574
- // Reduction dims are reduced, and are always outputed in the same
575
- // place. So use offset 0 for them.
576
- resultOffsets.push_back (b.getIndexAttr (0 ));
577
- } else {
578
- resultOffsets.push_back (offsets[dim]);
579
- }
580
- }
651
+ InitSliceInfo sliceInfo =
652
+ getInitSliceInfo (b.getContext (), tilingStrategy, ivs, offsets, sizes,
653
+ reductionDims, partialReductionMaps[resultNumber]);
654
+ std::swap (resultOffsets, sliceInfo.offsets );
655
+ std::swap (resultSizes, sliceInfo.sizes );
581
656
582
657
return success ();
583
658
}
0 commit comments