Skip to content

Commit b03fcbe

Browse files
[mlir][PartialReductionTilingInterface] Add support for ReductionTilingStrategy::PartialReductionOuterParallel in tileUsingSCF.
Following up from #143467, this PR adds support for `ReductionTilingStrategy::PartialReductionOuterParallel` to `tileUsingSCF`. The implementation of `PartialReductionTilingInterface` for `Linalg` ops has been updated to support this strategy as well. This brings `tileUsingSCF` on par with `linalg::tileReductionUsingForall`, which will be deprecated subsequently. Summary of changes - `PartialReductionTilingInterface` changes: - The `tileToPartialReduction` method needed to get the induction variables of the generated tile loops. This was needed to keep the generated code similar to that of `linalg::tileReductionUsingForall`, specifically to create a simplified access for slicing the intermediate partial results tensor when tiled in `num_threads` mode. - The `getPartialResultTilePosition` method needs the induction variables of the generated tile loops for the same reason as above, and also needs the `tilingStrategy` to be passed in to generate correct code. The tests in `transform-tile-reduction.mlir` that were testing `linalg::tileReductionUsingForall` have been moved over to test `scf::tileUsingSCF` with the `ReductionTilingStrategy::PartialReductionOuterParallel` strategy. Some of the tests that were doing further cyclic distribution of the transformed code from tiling have been removed. Those seem like two separate transformations that were merged into one. Ideally, that distribution would happen when resolving the `scf.forall` rather than during tiling. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent c348b19 commit b03fcbe

File tree

8 files changed

+217
-221
lines changed

8 files changed

+217
-221
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
156156
/// corresponding pair of arrays. This is the inverse function of
157157
/// `getMixedValues`.
158158
std::pair<SmallVector<int64_t>, SmallVector<Value>>
159-
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
159+
decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
160160

161161
/// Helper to sort `values` according to matching `keys`.
162162
SmallVector<Value>

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def PartialReductionOpInterface :
404404
"Location ":$loc,
405405
"::mlir::ReductionTilingStrategy":$tilingStrategy,
406406
"ValueRange":$init,
407+
"ValueRange":$ivs,
407408
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
408409
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
409410
"const ::llvm::SetVector<unsigned> &":$reductionDims),
@@ -442,6 +443,8 @@ def PartialReductionOpInterface :
442443
/*args=*/(ins
443444
"::mlir::OpBuilder &":$b,
444445
"unsigned":$resultNumber,
446+
"ValueRange":$ivs,
447+
"ReductionTilingStrategy":$tilingStrategy,
445448
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
446449
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
447450
"const ::mlir::SetVector<unsigned> &":$reductionDims,

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,23 +2864,41 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
28642864
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
28652865
SmallVector<OpFoldResult> tileSizes =
28662866
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2867-
FailureOr<linalg::ForallReductionTilingResult> result =
2868-
linalg::tileReductionUsingForall(
2869-
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2870-
numThreads, tileSizes, getMapping());
2867+
2868+
scf::SCFTilingOptions options;
2869+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
2870+
options.setReductionTilingStrategy(
2871+
ReductionTilingStrategy::PartialReductionOuterParallel);
2872+
if (!getNumThreads().empty()) {
2873+
options.setNumThreads(numThreads);
2874+
} else {
2875+
options.setTileSizes(tileSizes);
2876+
}
2877+
if (auto mapping = getMapping()) {
2878+
options.setMapping(mapping.value().getValue());
2879+
}
2880+
SmallVector<unsigned> reductionDims;
2881+
for (auto [idx, iteratorType] :
2882+
llvm::enumerate(target.getIteratorTypesArray()))
2883+
if (iteratorType == utils::IteratorType::reduction)
2884+
reductionDims.push_back(idx);
2885+
options.setReductionDims(reductionDims);
2886+
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
2887+
rewriter, cast<TilingInterface>(target.getOperation()), options);
28712888

28722889
if (failed(result)) {
28732890
auto diag = emitSilenceableError() << "could not tile reduction";
2874-
diag.attachNote(target.getLoc()) << "target operation";
28752891
return diag;
28762892
}
2893+
rewriter.replaceOp(target, result->replacements);
2894+
28772895
for (Value initValue : result->initialValues)
28782896
results.push_back(initValue.getDefiningOp());
2879-
for (auto parallelTiledOp : result->parallelTiledOps)
2897+
for (auto parallelTiledOp : result->tiledOps)
28802898
results.push_back(parallelTiledOp);
28812899
for (auto mergeOp : result->mergeOps)
28822900
results.push_back(mergeOp);
2883-
results.push_back(result->loops);
2901+
results.push_back(result->loops.front());
28842902
return DiagnosedSilenceableFailure::success();
28852903
}
28862904

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 127 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,17 @@ struct LinalgOpTilingInterface
328328
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
329329
//===----------------------------------------------------------------------===//
330330

331+
/// In a given set vector, get the position of a particular element.
332+
std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
333+
unsigned value) {
334+
for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
335+
if (reductionDim == value) {
336+
return index;
337+
}
338+
}
339+
return std::nullopt;
340+
}
341+
331342
/// Return an AffineMaps to use for the `outs` operands of the linalg op
332343
/// generated for partial results. The new AffineMap is the AffineMap of the
333344
/// untiled op with reduction dimensions appended at end in order in which they
@@ -348,28 +359,79 @@ getPartialResultAffineMaps(LinalgOp linalgOp,
348359
return partialReductionMaps;
349360
}
350361

351-
/// Return the slice of the `initValue` to use as input to the partial reduction
352-
/// op generated.
353-
static Operation *getInitSliceForOuterReduction(
354-
OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
362+
struct InitSliceInfo {
363+
SmallVector<int64_t> resultShape;
364+
SmallVector<OpFoldResult> offsets;
365+
SmallVector<OpFoldResult> sizes;
366+
SmallVector<OpFoldResult> strides;
367+
};
368+
369+
/// Return the result type, offsets, sizes and strides of the slice of the
370+
/// `initValue` to use as input to the partial reduction op generated with
371+
/// outer reduction strategy.
372+
static InitSliceInfo getInitSliceInfoForOuterReduction(
373+
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
355374
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
356375
AffineMap partialReductionMap) {
357376
int64_t initRank = partialReductionMap.getNumResults();
358377
SmallVector<OpFoldResult> initOffsets, initSizes;
359-
SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
378+
Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
379+
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
380+
SmallVector<OpFoldResult> initStrides(initRank, one);
360381
for (AffineExpr dimExpr : partialReductionMap.getResults()) {
361382
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
362383
if (reductionDims.contains(dim)) {
363-
initOffsets.push_back(b.getIndexAttr(0));
384+
initOffsets.push_back(zero);
364385
} else {
365386
initOffsets.push_back(offsets[dim]);
366387
}
367388
initSizes.push_back(sizes[dim]);
368389
}
369-
// TODO: Use SubsetExtractOpInterface here once available.
370-
auto extractSlice = b.create<tensor::ExtractSliceOp>(
371-
loc, initValue, initOffsets, initSizes, initStrides);
372-
return extractSlice;
390+
SmallVector<int64_t> resultShape;
391+
std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
392+
return {resultShape, initOffsets, initSizes, initStrides};
393+
}
394+
395+
/// Return the result type, offsets, sizes and strides of the slice of the
396+
/// `initValue` to use as input to the partial reduction op generated with
397+
/// outer parallel strategy.
398+
static InitSliceInfo getInitSliceInfoForOuterParallel(
399+
MLIRContext *context, ValueRange ivs, ArrayRef<OpFoldResult> offsets,
400+
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
401+
AffineMap partialReductionMap) {
402+
int64_t initRank = partialReductionMap.getNumResults();
403+
SmallVector<OpFoldResult> initOffsets, initSizes;
404+
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
405+
SmallVector<OpFoldResult> initStrides(initRank, one);
406+
SmallVector<OpFoldResult> resultShape;
407+
for (AffineExpr dimExpr : partialReductionMap.getResults()) {
408+
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
409+
if (std::optional<int> dimPos = getPositionIn(reductionDims, dim)) {
410+
initOffsets.push_back(ivs[dimPos.value()]);
411+
initSizes.push_back(one);
412+
} else {
413+
initOffsets.push_back(offsets[dim]);
414+
initSizes.push_back(sizes[dim]);
415+
resultShape.push_back(sizes[dim]);
416+
}
417+
}
418+
SmallVector<int64_t> staticShapes;
419+
std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
420+
return {staticShapes, initOffsets, initSizes, initStrides};
421+
}
422+
423+
static InitSliceInfo getInitSliceInfo(
424+
MLIRContext *context, ReductionTilingStrategy strategy, ValueRange ivs,
425+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
426+
const SetVector<unsigned> &reductionDims, AffineMap partialReductionMap) {
427+
if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
428+
return getInitSliceInfoForOuterReduction(
429+
context, offsets, sizes, reductionDims, partialReductionMap);
430+
}
431+
assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
432+
"unexpected ReductionTilingStrategy");
433+
return getInitSliceInfoForOuterParallel(context, ivs, offsets, sizes,
434+
reductionDims, partialReductionMap);
373435
}
374436

375437
/// External model implementation of PartialReductionInterface for
@@ -439,18 +501,11 @@ struct LinalgOpPartialReductionInterface
439501
return inits;
440502
}
441503

442-
FailureOr<TilingResult>
443-
tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
444-
ReductionTilingStrategy tilingStrategy,
445-
ValueRange init, ArrayRef<OpFoldResult> offsets,
446-
ArrayRef<OpFoldResult> sizes,
447-
const SetVector<unsigned> &reductionDims) const {
448-
if (tilingStrategy !=
449-
ReductionTilingStrategy::PartialReductionOuterReduction) {
450-
// TODO: Add support for `PartialReductionOuterParallel` strategy.
451-
return op->emitOpError("unsupported partial reduction tiling with "
452-
"`PartialReductionOuterParallel` strategy");
453-
}
504+
FailureOr<TilingResult> tileToPartialReduction(
505+
Operation *op, OpBuilder &b, Location loc,
506+
ReductionTilingStrategy tilingStrategy, ValueRange init, ValueRange ivs,
507+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
508+
const SetVector<unsigned> &reductionDims) const {
454509
OpBuilder::InsertionGuard guard(b);
455510
auto linalgOp = cast<LinalgOp>(op);
456511

@@ -459,7 +514,16 @@ struct LinalgOpPartialReductionInterface
459514

460515
// Step 1. Extend init maps to have reduction dimension dims, since we
461516
// are converting them to parallel dimensions.
462-
SmallVector<AffineMap> newInitMaps = partialReductionMaps;
517+
SmallVector<AffineMap> newInitMaps;
518+
if (tilingStrategy ==
519+
ReductionTilingStrategy::PartialReductionOuterReduction) {
520+
newInitMaps = llvm::to_vector(partialReductionMaps);
521+
} else {
522+
newInitMaps = llvm::map_to_vector(
523+
linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
524+
return linalgOp.getMatchingIndexingMap(&opOperand);
525+
});
526+
}
463527

464528
// Step 2a: Extract a slice of the input operands.
465529
SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -473,10 +537,17 @@ struct LinalgOpPartialReductionInterface
473537
SmallVector<Value, 1> tiledInits;
474538
for (auto [partialReductionMap, valueToTile] :
475539
llvm::zip_equal(partialReductionMaps, init)) {
476-
Operation *sliceOp =
477-
getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
478-
reductionDims, partialReductionMap);
479-
tiledInits.push_back(sliceOp->getResult(0));
540+
InitSliceInfo sliceInfo =
541+
getInitSliceInfo(b.getContext(), tilingStrategy, ivs, offsets, sizes,
542+
reductionDims, partialReductionMap);
543+
auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
544+
RankedTensorType sliceResultType = RankedTensorType::get(
545+
sliceInfo.resultShape, valueToTileType.getElementType(),
546+
valueToTileType.getEncoding());
547+
auto sliceOp = b.create<tensor::ExtractSliceOp>(
548+
loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes,
549+
sliceInfo.strides);
550+
tiledInits.push_back(sliceOp.getResult());
480551
generatedSlices.push_back(sliceOp);
481552
}
482553

@@ -491,19 +562,31 @@ struct LinalgOpPartialReductionInterface
491562
// Step 3. Change the reduction dim iterator types.
492563
SmallVector<utils::IteratorType> newIteratorTypes =
493564
linalgOp.getIteratorTypesArray();
494-
for (int dim : reductionDims)
495-
newIteratorTypes[dim] = utils::IteratorType::parallel;
565+
if (tilingStrategy ==
566+
ReductionTilingStrategy::PartialReductionOuterReduction) {
567+
for (int dim : reductionDims)
568+
newIteratorTypes[dim] = utils::IteratorType::parallel;
569+
}
496570

497571
// Step 4. Create the new generic op.
572+
Operation *partialReductionOp;
498573
auto resultTypes = ValueRange(tiledInits).getTypes();
499-
auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
500-
tiledInits, newMaps, newIteratorTypes);
501-
IRMapping mapping;
502-
op->getRegion(0).cloneInto(&genericOp.getRegion(),
503-
genericOp.getRegion().begin(), mapping);
574+
if (tilingStrategy ==
575+
ReductionTilingStrategy::PartialReductionOuterReduction) {
576+
auto genericOp = b.create<GenericOp>(
577+
loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
578+
IRMapping mapping;
579+
op->getRegion(0).cloneInto(&genericOp.getRegion(),
580+
genericOp.getRegion().begin(), mapping);
581+
partialReductionOp = genericOp.getOperation();
582+
} else {
583+
SmallVector<Value> operands = std::move(tiledInputs);
584+
llvm::append_range(operands, tiledInits);
585+
partialReductionOp = mlir::clone(b, op, resultTypes, operands);
586+
}
504587
return TilingResult{
505-
{genericOp.getOperation()},
506-
llvm::map_to_vector(genericOp->getResults(),
588+
{partialReductionOp},
589+
llvm::map_to_vector(partialReductionOp->getResults(),
507590
[](OpResult r) -> Value { return r; }),
508591
generatedSlices};
509592
}
@@ -557,27 +640,19 @@ struct LinalgOpPartialReductionInterface
557640
}
558641

559642
LogicalResult getPartialResultTilePosition(
560-
Operation *op, OpBuilder &b, unsigned resultNumber,
561-
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
562-
const SetVector<unsigned> &reductionDims,
643+
Operation *op, OpBuilder &b, unsigned resultNumber, ValueRange ivs,
644+
ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
645+
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
563646
SmallVector<OpFoldResult> &resultOffsets,
564647
SmallVector<OpFoldResult> &resultSizes) const {
565648
auto linalgOp = cast<LinalgOp>(op);
566649
SmallVector<AffineMap> partialReductionMaps =
567650
getPartialResultAffineMaps(linalgOp, reductionDims);
568-
569-
for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
570-
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
571-
resultSizes.push_back(sizes[dim]);
572-
573-
if (llvm::is_contained(reductionDims, dim)) {
574-
// Reduction dims are reduced, and are always outputed in the same
575-
// place. So use offset 0 for them.
576-
resultOffsets.push_back(b.getIndexAttr(0));
577-
} else {
578-
resultOffsets.push_back(offsets[dim]);
579-
}
580-
}
651+
InitSliceInfo sliceInfo =
652+
getInitSliceInfo(b.getContext(), tilingStrategy, ivs, offsets, sizes,
653+
reductionDims, partialReductionMaps[resultNumber]);
654+
std::swap(resultOffsets, sliceInfo.offsets);
655+
std::swap(resultSizes, sliceInfo.sizes);
581656

582657
return success();
583658
}

0 commit comments

Comments
 (0)