[mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface #92624
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

Changes

Patch is 29.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92624.diff

10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
- let results = (outs TransformHandleTypeInterface:$fill_op,
+ let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_linalg_op,
TransformHandleTypeInterface:$combining_linalg_op,
TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
- let results = (outs TransformHandleTypeInterface:$fill_op,
+ let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_linalg_op,
TransformHandleTypeInterface:$combining_linalg_op,
TransformHandleTypeInterface:$forall_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..308ce92e35520 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
Operation *parallelTiledOp;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
- /// The op initializing the tensor used for partial reductions.
- Operation *initialOp;
+ /// Initial values used for partial reductions.
+ SmallVector<Value> initialValues;
/// The `scf.forall` operation that iterate over the tiles.
scf::ForallOp loops;
};
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..6d567171e185a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
Operation *parallelTiledOp;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
- /// Initial op
- Operation *initialOp;
+ /// Initial values used for reduction.
+ SmallVector<Value> initialValues;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
};
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..6fff7e0da538a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
operation reduction. The tensor shape is equal to operation result
shape with new dimension for each non zero tile size.
}],
- /*retType=*/"FailureOr<Operation*>",
+ /*retType=*/"FailureOr<Value>",
/*methodName=*/"generateInitialTensorForPartialReduction",
/*args=*/(ins
"OpBuilder &":$b,
- "Location ":$loc,
+ "Location":$loc,
+ "int64_t":$resultNumber,
"ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<int>":$reductionDim),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9b3121774ab3a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
if (failed(result))
return emitDefaultSilenceableFailure(target);
- results.push_back(result->initialOp);
+ for (Value initValue : result->initialValues)
+ results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
diag.attachNote(target.getLoc()) << "target operation";
return diag;
}
- results.push_back(result->initialOp);
+ for (Value initValue : result->initialValues)
+ results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
results.push_back(result->loops);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..e0394f852fcc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
PartialReductionOpInterface partialReductionIface =
llvm::cast<PartialReductionOpInterface>(op.getOperation());
- FailureOr<Operation *> reductionNeutralTensorOp =
+ assert(op->getNumResults() == 1 && "Multiple results not supported.");
+ FailureOr<Value> reductionNeutralTensor =
partialReductionIface.generateInitialTensorForPartialReduction(
- builder, builder.getLoc(), shape, {});
- assert(succeeded(reductionNeutralTensorOp));
- builder.create<scf::YieldOp>(
- reductionNeutralTensorOp.value()->getResult(0));
+ builder, builder.getLoc(), 0, shape, {});
+ assert(succeeded(reductionNeutralTensor));
+ builder.create<scf::YieldOp>(reductionNeutralTensor.value());
}
return ifOp.getResult(0);
}
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
ImplicitLocOpBuilder &builder) {
// TODO: add support for multiple destination passing style initial value
// operands.
- // PartialReductionOpInterface::generateInitialTensorForPartialReduction
- // needs to also support multiple DPS initial operands.
+ assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
Value spmdizedInitOperand =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..d4805611a68b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -657,9 +657,9 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
// Ops implementing PartialReductionOpInterface are not necessarily expected
- // to implement TilingInterface.. This cast is unsafe atm.
+ // to implement DestinationStyleOpInterface. This cast is unsafe atm.
// TODO: proper core mechanism to tie interfaces together.
- // TODO: this function requires a pair of interfaces ..
+ // TODO: this function requires a pair of interfaces.
auto destinationStyleOp =
dyn_cast<DestinationStyleOpInterface>(op.getOperation());
if (!destinationStyleOp)
@@ -671,9 +671,6 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
return b.notifyMatchFailure(op, "not a linalg op");
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
- if (op->getNumResults() != 1)
- return b.notifyMatchFailure(
- op, "don't support ops with multiple results for now");
SmallVector<utils::IteratorType> iterators =
tilingInterfaceOp.getLoopIteratorTypes();
@@ -692,12 +689,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
op, "reduction dimension must be mapped to threads");
// 1. Create the initial tensor values.
- FailureOr<Operation *> identityTensor =
- op.generateInitialTensorForPartialReduction(b, loc, numThreads,
- reductionDim);
- if (failed(identityTensor))
- return b.notifyMatchFailure(op,
- "cannot create a tensor of identity value.");
+ SmallVector<Value> initTensors;
+ initTensors.reserve(op->getNumResults());
+ for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+ FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
+ b, loc, idx, numThreads, reductionDim);
+ if (failed(initValue))
+ return b.notifyMatchFailure(
+ op, "cannot create a tensor of identity value for result " +
+ std::to_string(idx));
+ initTensors.push_back(initValue.value());
+ }
// Gather destination tensors.
SmallVector<Value> dest;
@@ -715,8 +717,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 2. Create the ForallOp with an empty region.
scf::ForallOp forallOp = b.create<scf::ForallOp>(
- loc, getAsOpFoldResult(materializedNonZeroNumThreads),
- (*identityTensor)->getResults(), mapping);
+ loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+ mapping);
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
// be nested under `forallOp`.
@@ -726,7 +728,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
/*nominalTileSizes=*/std::nullopt, tiledOffsets,
tiledSizes);
- // 4. Clone the tileable op and update its destination operands to use the
+ // 4b. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForallOp.
SmallVector<Value> tilingResults;
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +840,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 8. Return.
ForallReductionTilingResult results;
- results.initialOp = *identityTensor;
+ results.initialValues = initTensors;
results.loops = forallOp;
results.parallelTiledOp = tiledOp;
results.mergeOp = mergeOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..edae03fec72a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
: public PartialReductionOpInterface::ExternalModel<
LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
- FailureOr<Operation *> generateInitialTensorForPartialReduction(
- Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
- ArrayRef<int> reductionDims) const {
+ FailureOr<Value> generateInitialTensorForPartialReduction(
+ Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+ ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
OpBuilder::InsertionGuard guard(b);
@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
// loops. This could be controlled by user for more flexibility.
SmallVector<Operation *, 4> combinerOps;
- if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+ if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+ combinerOps) ||
combinerOps.size() != 1)
return op->emitOpError("Failed to analyze the reduction operation.");
@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
"Failed to get an identity value for the reduction operation.");
ArrayRef<int64_t> oldShape =
- linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+ linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));
// Calculate the new shape, we insert the new dimensions based on the index
// of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
newOutputShape.push_back(dim);
if (ShapedType::isDynamic(dim))
dynamicDims.push_back(b.create<tensor::DimOp>(
- loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+ loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
}
Value emptyTensor = b.create<tensor::EmptyOp>(
- loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
- dynamicDims);
+ loc, newOutputShape,
+ linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
auto identityTensor =
b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
- return identityTensor.getOperation();
+ return identityTensor.getResult(0);
}
Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
@@ -312,44 +313,58 @@ struct LinalgOpPartialReductionInterface
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);
- AffineMap oldOutputMap =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
- SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
- reductionDims.size());
-
- for (int idx : reductionDims)
- outputExpr[idx] = b.getAffineDimExpr(idx);
- int currExpr = 0;
- for (int idx : llvm::seq<int>(0, outputExpr.size())) {
- if (outputExpr[idx])
- continue;
- outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+ // Step 1. Extend init maps to have reduction dimension dims, since we
+ // are converting them to parallel dimensions.
+ SmallVector<AffineMap> newInitMaps;
+ newInitMaps.reserve(linalgOp.getNumDpsInits());
+ for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+ // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+ // this with a for range loop when we have it.
+ AffineMap newMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+ for (int redPos : reductionDims) {
+ newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+ newMap.getNumResults());
+ }
+ newInitMaps.push_back(newMap);
}
- // Step 1: Extract a slice of the input operands.
- SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
- SmallVector<Value, 4> tiledOperands = makeTiledShapes(
- b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+ // Step 2a: Extract a slice of the input operands.
+ SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+ b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+ // Step 2b: Extract a slice of the init operands.
+ SmallVector<Value, 1> tiledInits;
+ for (Value valueToTile : init) {
+ SmallVector<OpFoldResult> initOffset(offsets.size(), b.getIndexAttr(0));
+ SmallVector<OpFoldResult> initStride(offsets.size(), b.getIndexAttr(1));
+ // TODO: Use SubsetExtractOpInterface here once available.
+ auto extractSlice = b.create<tensor::ExtractSliceOp>(
+ loc, valueToTile, initOffset, sizes, initStride);
+ tiledInits.push_back(extractSlice);
+ }
- // Step 2: Extract the accumulator operands
- SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
- SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- // TODO: use SubsetExtractOpInterface once it is available.
- Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
- sizes, strides);
+ // Update the indexing maps.
+ SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+ // Change the init maps.
+ for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+ // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+ // this with a for range loop when we have it.
+ OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+ int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+ newMaps[mapIdx] = newInitMaps[idx];
+ }
- // Step3. Create a generic op where the reduction dimensions are replaced
- // by a parallel dimension of the size of reduction.
+ // Step 3. Change the reduction dim iterator types.
SmallVector<utils::IteratorType> newIteratorTypes =
linalgOp.getIteratorTypesArray();
for (int dim : reductionDims)
newIteratorTypes[dim] = utils::IteratorType::parallel;
- SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
- newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
- linalgOp.getContext());
+
+ // Step 4. Create the new generic op.
auto genericOp =
- b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
- ValueRange({out}), newMaps, newIteratorTypes);
+ b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+ tiledInits, newMaps, newIteratorTypes);
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
@@ -361,40 +376,53 @@ struct LinalgOpPartialReductionInterface
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
- DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
- // Then create a new reduction that only reduce the newly added dimensions
- // from the previous op.
- int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
- AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
- SmallVector<utils::IteratorType> reductionIteratorTypes;
- SmallVector<AffineExpr> exprs;
-
- for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
- if (reductionDimsSet.contains(i)) {
- reductionIteratorTypes.push_back(utils::IteratorType::reduction);
- } else {
- exprs.push_back(b.getAffineDimExpr(i));
- reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+ // Step 1. Recover the dims that actually need to be merged from the
+ // original operation. We can classify the original iterators as follows:
+ //
+ // parallel --> parallel
+ // reduction + not in reductionDims --> parallel (already reduced)
+ // reduction + in reductionDims --> reduction (will reduce now)
+ SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+ utils::IteratorType::parallel);
+ for (int redIdx : reductionDims)
+ iterators[redIdx] = utils::IteratorType::reduction;
+
+ // Step 2. For each partial result, create a map to index it. This map
+ // is simply the indexing map for the original result with reductionDims
+ // appended (as produced in tileToPartialReduction).
+ int64_t numInits = linalgOp.getNumDpsInits();
+ SmallVector<AffineMap> indexingMaps(numInits * 2);
+ for (int idx : llvm::seq<int>(0, numInits)) {
+ AffineMap &inputMap = indexingMaps[idx];
+ AffineMap &outputMap = indexingMaps[numInits + idx];
+
+ outputMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+ inputMap = outputMap;
+ for (int redPos : reductionDims) {...
[truncated]
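For orientation, this is the loop structure the transform produces; a hand-written sketch, assuming a single static-shaped f32 sum reduction tiled by 4 (the shapes, names, and combiner are illustrative, not taken from the patch or its tests). With this patch there is one tensor.empty/linalg.fill pair per op result rather than exactly one.

#map = affine_map<(d0, d1) -> (d0, d1)>
#red = affine_map<(d0, d1) -> (d0)>

func.func @sketch(%in: tensor<8x64xf32>, %out: tensor<8xf32>) -> tensor<8xf32> {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c64 = arith.constant 64 : index
  %cst = arith.constant 0.0 : f32
  // One init tensor (and fill) per result; $fill_op is now variadic.
  %empty = tensor.empty() : tensor<8x4xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x4xf32>) -> tensor<8x4xf32>
  // Partial reduction: the reduction dim is tiled by the loop; inside,
  // both dims are "parallel" and the combiner runs elementwise.
  %partial = scf.for %iv = %c0 to %c64 step %c4
      iter_args(%acc = %fill) -> (tensor<8x4xf32>) {
    %slice = tensor.extract_slice %in[0, %iv] [8, 4] [1, 1]
        : tensor<8x64xf32> to tensor<8x4xf32>
    %p = linalg.generic {
        indexing_maps = [#map, #map],
        iterator_types = ["parallel", "parallel"]}
        ins(%slice : tensor<8x4xf32>) outs(%acc : tensor<8x4xf32>) {
      ^bb0(%x: f32, %a: f32):
        %0 = arith.addf %x, %a : f32
        linalg.yield %0 : f32
    } -> tensor<8x4xf32>
    scf.yield %p : tensor<8x4xf32>
  }
  // Merge: reduce the appended tile dimension back into the real result.
  %merged = linalg.generic {
      indexing_maps = [#map, #red],
      iterator_types = ["parallel", "reduction"]}
      ins(%partial : tensor<8x4xf32>) outs(%out : tensor<8xf32>) {
    ^bb0(%x: f32, %a: f32):
      %0 = arith.addf %x, %a : f32
      linalg.yield %0 : f32
  } -> tensor<8xf32>
  return %merged : tensor<8xf32>
}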
Force-pushed from 3851aba to 8e890c8
// CHECK: func @reduction_tile_transpose
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
// CHECK: scf.for
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
- // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
+ // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
I wouldn't expect this to change... this should be NFC for the most part?
I refactored a bit of the implementation for multiple reduction dims as part of adding multiple results to the linalg partial tiling interface. Before, it tried to insert reduction dims in an ad-hoc way; now, it always inserts reduction dims at the end. This means that when they are reduced in the final loop, they are all contiguous. The order is also deterministic (the old implementation would not work for multiple results).
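To illustrate (a hypothetical shape and maps, not copied from the test file): for a sum over d0 whose result map is (d0, d1) -> (d1), the init map of the partial generic is now the original output map with the reduction dim d0 appended at the end, giving (d0, d1) -> (d1, d0), which is plausibly what the MAP1 -> MAP2 change above reflects.

func.func @partial(%slice: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // The tiled partial op: the reduction dim d0 is now "parallel", and the
  // init map carries d0 as its last result.
  %partial = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%slice : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%x: f32, %acc: f32):
      %0 = arith.addf %x, %acc : f32
      linalg.yield %0 : f32
  } -> tensor<?x?xf32>
  return %partial : tensor<?x?xf32>
}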
Ok, I think this looks OK. I am a bit fuzzy on the implementation details, but this seems fine to me.
(iree-org#17408) This patch generalizes the tiling implementation for AttentionOp. Before, only the batch and M dimensions of attention could be tiled. This patch instead allows tiling of the N dimension, as well as transposition based on indexing maps (hardcoded for now). Tiling on dimension N is disabled in the CPU backend for now, because the TileAndDecomposeAttention pass is hardcoded with dimensions. This will be fixed once we implement the reduction tiling interface for it (after llvm/llvm-project#92624)
This patch adds support for reducing operations with multiple results using PartialReductionOpInterface. It also adds a multiple-result implementation of PartialReductionOpInterface for linalg.generic.
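As a hedged illustration of what this enables (hand-written, not taken from the patch's tests), a single linalg.generic with two reduction results, say a simultaneous sum and max over the same dimension, can now go through PartialReductionOpInterface, with one init tensor generated per result:

func.func @multi_reduce(%in: tensor<?x?xf32>,
                        %sum_init: tensor<?xf32>,
                        %max_init: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
  // Two results sharing the reduction dim d1; the interface is queried
  // once per result number to build the partial-reduction init tensors.
  %sum, %max = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<?x?xf32>)
      outs(%sum_init, %max_init : tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%x: f32, %s: f32, %m: f32):
      %0 = arith.addf %x, %s : f32
      %1 = arith.maximumf %x, %m : f32
      linalg.yield %0, %1 : f32, f32
  } -> (tensor<?xf32>, tensor<?xf32>)
  return %sum, %max : tensor<?xf32>, tensor<?xf32>
}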