Skip to content

[mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface #120465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return the position of the partial result tile computed by
the tiled operation. This is the same as
TilingInterface::getResultTilePosition, but determines the result
tile position for partial reduction.
}],
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"getPartialResultTilePosition",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$resultNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
"::mlir::ArrayRef<int>":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>
];
}
Expand Down
34 changes: 21 additions & 13 deletions mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
SymbolTableCollection &symbolTable) {
for (const MeshSharding& sharding : operandShardings) {
for (const MeshSharding &sharding : operandShardings) {
if (sharding) {
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
}
}

for (const MeshSharding& sharding : resultShardings) {
for (const MeshSharding &sharding : resultShardings) {
if (sharding) {
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
}
Expand All @@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
// the original operand.
// The other processes would use the reduction operation neutral tensor.
static Value createDestinationPassingStyleInitOperand(
LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
MeshOp meshOp, ImplicitLocOpBuilder &builder) {
LinalgOp op, int operandNumber, Value spmdizedOperand,
ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
ImplicitLocOpBuilder &builder) {
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
meshOp.getSymName(), reductionMeshAxes, builder);
Value zero = builder.create<arith::ConstantIndexOp>(0);
Expand All @@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
SmallVector<OpFoldResult> shape =
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
PartialReductionOpInterface partialReductionIface =
llvm::cast<PartialReductionOpInterface>(op.getOperation());
assert(op->getNumResults() == 1 && "Multiple results not supported.");
FailureOr<SmallVector<Value>> reductionNeutralTensor =
partialReductionIface.generateInitialTensorForPartialReduction(
builder, builder.getLoc(), shape, {});
assert(succeeded(reductionNeutralTensor));
builder.create<scf::YieldOp>(reductionNeutralTensor.value());

SmallVector<Operation *> combinerOps;
matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
assert(combinerOps.size() == 1);
std::optional<TypedAttr> neutralEl =
arith::getNeutralElement(combinerOps[0]);

Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
neutralEl.value().getType());
Value constant =
builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
.getResult(0);

builder.create<scf::YieldOp>(fill);
}
return ifOp.getResult(0);
}
Expand All @@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
Value spmdizedInitOperand =
spmdizationMap.lookup(op->getOperands()[operandIdx]);
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
return newOperands;
}

Expand Down
165 changes: 113 additions & 52 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

/// External model implementation of PartialReductionInterface for LinalgOps.
/// Return an AffineMap for a partial result for the given result number,
/// assuming the partial tiling strategy is outer-reduction loop +
/// inner-parallel tile. The returned AffineMap can be used as the replacement
/// AffineMap for the inner-parallel tile linalg op for the given result number.
///
/// The new AffineMap is the old AffineMap with reduction dimensions appended
/// at end.
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
                                           ArrayRef<int> reductionDims,
                                           unsigned resultNumber) {
  // Start from the indexing map of the init operand for this result.
  AffineMap resultMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
  // Append one result expression per reduction dimension, in order, so the
  // partial tile carries the (not yet reduced) reduction iteration space.
  for (int reductionDim : reductionDims)
    resultMap = resultMap.insertResult(
        getAffineDimExpr(reductionDim, linalgOp.getContext()),
        resultMap.getNumResults());
  return resultMap;
}

/// External model implementation of PartialReductionInterface for
/// LinalgOps.
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
: public PartialReductionOpInterface::ExternalModel<
Expand All @@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
if (linalgOp.hasPureBufferSemantics())
return op->emitOpError("expected operation to have tensor semantics");

// LinalgOp implements TilingInterface.
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
SmallVector<OpFoldResult> shape =
llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
[](Range x) { return x.size; });

SmallVector<OpFoldResult> tiledShape;
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
if (isZeroIndex(tileSize)) {
tiledShape.push_back(dimSize);
} else {
tiledShape.push_back(tileSize);
}
}

SmallVector<Value> inits;
for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
++initIdx) {
// Insert the new parallel dimension based on the index of the reduction
// loops. This could be controlled by user for more flexibility.
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
combinerOps) ||
Expand All @@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
return op->emitOpError(
"Failed to get an identity value for the reduction operation.");

ArrayRef<int64_t> oldShape =
linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));

// Calculate the new shape, we insert the new dimensions based on the
// index of the reduction dimensions.
SmallVector<int64_t> newOutputShape;
SmallVector<Value> dynamicDims;
int64_t currReductionDims = 0;
DenseSet<int> reductionDimsSet(reductionDims.begin(),
reductionDims.end());
for (int64_t idx :
llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
if (reductionDimsSet.contains(idx)) {
dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
currReductionDims++;
continue;
}
int64_t oldIdx = idx - currReductionDims;
int64_t dim = oldShape[oldIdx];
newOutputShape.push_back(dim);
if (ShapedType::isDynamic(dim))
dynamicDims.push_back(b.create<tensor::DimOp>(
loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
// Append the new partial result dimensions.
AffineMap partialMap =
getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
SmallVector<OpFoldResult> partialResultShape;
for (AffineExpr dimExpr : partialMap.getResults()) {
auto dim = cast<AffineDimExpr>(dimExpr);
partialResultShape.push_back(tiledShape[dim.getPosition()]);
}
Value emptyTensor = b.create<tensor::EmptyOp>(
loc, newOutputShape,
linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);

Type elType =
getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
Value emptyTensor =
b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
auto identityTensor =
b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
Expand All @@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
// TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
// this with a for range loop when we have it.
AffineMap newMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
for (int redPos : reductionDims) {
newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
newMap.getNumResults());
}
getPartialResultAffineMap(linalgOp, reductionDims, idx);
newInitMaps.push_back(newMap);
}

Expand Down Expand Up @@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
Location loc, ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
SmallVector<int64_t> reductionDimsInt64(reductionDims);
auto reduction = b.create<linalg::ReduceOp>(
loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
[&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<Value> yieldedValues;
for (int idx : llvm::seq<int>(0, numInits)) {

// Permute the reduction dims as permuted by the partial result map.

int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<Operation *> mergeOperations;
SmallVector<Value> replacements;
for (int idx : llvm::seq(numInits)) {
// linalg.reduce's iteration space is the tiled result's iteration space
// (and not the tiled operation's iteration space). To account for this,
// permute the reduction dimensions based on the partial result map of the
// tiled result.
AffineMap partialMap =
getPartialResultAffineMap(linalgOp, reductionDims, idx);
SmallVector<int64_t> partialReductionDims;
for (auto [resultNum, dimExpr] :
llvm::enumerate(partialMap.getResults())) {
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
if (llvm::find(reductionDims, dim) != reductionDims.end()) {
partialReductionDims.push_back(resultNum);
}
}

Value partialResult = partialReduce[idx];
Value init = linalgOp.getDpsInits()[idx];

auto reduction = b.create<linalg::ReduceOp>(
loc, partialResult, init, partialReductionDims,
[&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
// Get the combiner op.
SmallVector<Operation *, 4> combinerOps;
matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
Operation *clonedReductionOp = b.clone(*combinerOps[0]);
// Combine the input at idx and output at numInits + idx.
clonedReductionOp->setOperand(0, inputs[idx]);
clonedReductionOp->setOperand(1, inputs[numInits + idx]);
// Yield.
yieldedValues.push_back(clonedReductionOp->getResult(0));
}
b.create<linalg::YieldOp>(loc, yieldedValues);
});
return MergeResult{
{reduction.getOperation()},
llvm::map_to_vector(reduction->getResults(),
[](OpResult r) -> Value { return r; })};
clonedReductionOp->setOperand(0, inputs[0]);
clonedReductionOp->setOperand(1, inputs[1]);
b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
});

mergeOperations.push_back(reduction);
replacements.push_back(reduction->getResult(0));
}

return MergeResult{mergeOperations, replacements};
}

LogicalResult getPartialResultTilePosition(
    Operation *op, OpBuilder &b, unsigned resultNumber,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes,
    ArrayRef<int> reductionDims) const {
  auto linalgOp = cast<LinalgOp>(op);

  // The partial result is indexed by the init operand's map with the
  // reduction dimensions appended (see getPartialResultAffineMap). Walk that
  // map to translate iteration-space offsets/sizes into result-space ones.
  AffineMap partialMap =
      getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
  for (AffineExpr dimExpr : partialMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
    resultSizes.push_back(sizes[dim]);

    if (llvm::is_contained(reductionDims, dim)) {
      // Reduction dims are not reduced away in the partial result and are
      // always written at the same position, so use offset 0 for them.
      resultOffsets.push_back(b.getIndexAttr(0));
    } else {
      resultOffsets.push_back(offsets[dim]);
    }
  }

  return success();
}
};

Expand Down
28 changes: 18 additions & 10 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
resultOffset, resultSize);
case scf::SCFTilingOptions::ReductionTilingStrategy::
PartialReductionOuterReduction: {
// TODO: This does not work for non identity accesses to the result tile.
// The proper fix is to add a getPartialResultTilePosition method to
// PartialReductionOpInterface.
resultOffset =
SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
for (size_t i = 0; i < offsets.size(); i++) {
resultSize.push_back(
tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
if (!redOp) {
return rewriter.notifyMatchFailure(
op, "PartialReductionOuterReduction tiling strategy is only supported"
"for operations implementing PartialReductionOpInterface");
}
return success();
// Get reduction dimensions.
// TODO: PartialReductionOpInterface should really query TilingInterface
// itself and find reduction dimensions.
SmallVector<int> reductionDims;
for (auto [idx, iteratorType] :
llvm::enumerate(op.getLoopIteratorTypes())) {
if (iteratorType == utils::IteratorType::reduction)
reductionDims.push_back(idx);
}
return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
resultOffset, resultSize,
reductionDims);
}
default:
return rewriter.notifyMatchFailure(op,
"unhandled reduction tiling strategy");
}
}
}

static FailureOr<MergeResult>
Expand Down
Loading
Loading