Skip to content

Commit 0cc40fc

Browse files
committed
[mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface
1 parent 83d8a8c commit 0cc40fc

File tree

4 files changed

+196
-78
lines changed

4 files changed

+196
-78
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
427427
/*defaultImplementation=*/[{
428428
return failure();
429429
}]
430+
>,
431+
InterfaceMethod<
432+
/*desc=*/[{
433+
Method to return the position of the partial result tile computed by
434+
the tiled operation. This is the same as
435+
TilingInterface::getResultTilePosition, but determines the result
436+
tile position for partial reduction.
437+
}],
438+
/*retType=*/"::llvm::LogicalResult",
439+
/*methodName=*/"getPartialResultTilePosition",
440+
/*args=*/(ins
441+
"::mlir::OpBuilder &":$b,
442+
"unsigned":$resultNumber,
443+
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
444+
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
445+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
446+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
447+
"::mlir::ArrayRef<int>":$reductionDims),
448+
/*methodBody=*/"",
449+
/*defaultImplementation=*/[{
450+
return failure();
451+
}]
430452
>
431453
];
432454
}

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 105 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,20 @@ struct LinalgOpTilingInterface
324324
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
325325
//===----------------------------------------------------------------------===//
326326

327-
/// External model implementation of PartialReductionInterface for LinalgOps.
327+
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
328+
ArrayRef<int> reductionDims,
329+
unsigned resultNumber) {
330+
AffineMap map =
331+
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
332+
for (int redPos : reductionDims) {
333+
map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
334+
map.getNumResults());
335+
}
336+
return map;
337+
}
338+
339+
/// External model implementation of PartialReductionInterface for
340+
/// LinalgOps.
328341
template <typename LinalgOpTy>
329342
struct LinalgOpPartialReductionInterface
330343
: public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +351,24 @@ struct LinalgOpPartialReductionInterface
338351
if (linalgOp.hasPureBufferSemantics())
339352
return op->emitOpError("expected operation to have tensor semantics");
340353

354+
// LinalgOp implements TilingInterface.
355+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
356+
SmallVector<OpFoldResult> shape =
357+
llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
358+
[](Range x) { return x.size; });
359+
360+
SmallVector<OpFoldResult> tiledShape;
361+
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
362+
if (isZeroIndex(tileSize)) {
363+
tiledShape.push_back(dimSize);
364+
} else {
365+
tiledShape.push_back(tileSize);
366+
}
367+
}
368+
341369
SmallVector<Value> inits;
342370
for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
343371
++initIdx) {
344-
// Insert the new parallel dimension based on the index of the reduction
345-
// loops. This could be controlled by user for more flexibility.
346372
SmallVector<Operation *, 4> combinerOps;
347373
if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
348374
combinerOps) ||
@@ -355,33 +381,19 @@ struct LinalgOpPartialReductionInterface
355381
return op->emitOpError(
356382
"Failed to get an identity value for the reduction operation.");
357383

358-
ArrayRef<int64_t> oldShape =
359-
linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
360-
361-
// Calculate the new shape, we insert the new dimensions based on the
362-
// index of the reduction dimensions.
363-
SmallVector<int64_t> newOutputShape;
364-
SmallVector<Value> dynamicDims;
365-
int64_t currReductionDims = 0;
366-
DenseSet<int> reductionDimsSet(reductionDims.begin(),
367-
reductionDims.end());
368-
for (int64_t idx :
369-
llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
370-
if (reductionDimsSet.contains(idx)) {
371-
dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
372-
currReductionDims++;
373-
continue;
374-
}
375-
int64_t oldIdx = idx - currReductionDims;
376-
int64_t dim = oldShape[oldIdx];
377-
newOutputShape.push_back(dim);
378-
if (ShapedType::isDynamic(dim))
379-
dynamicDims.push_back(b.create<tensor::DimOp>(
380-
loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
384+
// Append the new partial result dimensions.
385+
AffineMap partialMap =
386+
getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
387+
SmallVector<OpFoldResult> partialResultShape;
388+
for (AffineExpr dimExpr : partialMap.getResults()) {
389+
auto dim = cast<AffineDimExpr>(dimExpr);
390+
partialResultShape.push_back(tiledShape[dim.getPosition()]);
381391
}
382-
Value emptyTensor = b.create<tensor::EmptyOp>(
383-
loc, newOutputShape,
384-
linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
392+
393+
Type elType =
394+
getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
395+
Value emptyTensor =
396+
b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
385397
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
386398
auto identityTensor =
387399
b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +419,7 @@ struct LinalgOpPartialReductionInterface
407419
// TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
408420
// this with a for range loop when we have it.
409421
AffineMap newMap =
410-
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
411-
for (int redPos : reductionDims) {
412-
newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
413-
newMap.getNumResults());
414-
}
422+
getPartialResultAffineMap(linalgOp, reductionDims, idx);
415423
newInitMaps.push_back(newMap);
416424
}
417425

@@ -476,29 +484,74 @@ struct LinalgOpPartialReductionInterface
476484
Location loc, ValueRange partialReduce,
477485
ArrayRef<int> reductionDims) const {
478486
auto linalgOp = cast<LinalgOp>(op);
479-
SmallVector<int64_t> reductionDimsInt64(reductionDims);
480-
auto reduction = b.create<linalg::ReduceOp>(
481-
loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
482-
[&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
483-
int64_t numInits = linalgOp.getNumDpsInits();
484-
SmallVector<Value> yieldedValues;
485-
for (int idx : llvm::seq<int>(0, numInits)) {
487+
488+
// Permute the reduction dims as permuted by the partial result map.
489+
490+
int64_t numInits = linalgOp.getNumDpsInits();
491+
SmallVector<Operation *> mergeOperations;
492+
SmallVector<Value> replacements;
493+
for (int idx : llvm::seq(numInits)) {
494+
// linalg.reduce's iteration space is the result's iteration space (and
495+
// not the operation's iteration space). To account for this, permute the
496+
// reduction dimensions based on the partial result map.
497+
AffineMap partialMap =
498+
getPartialResultAffineMap(linalgOp, reductionDims, idx);
499+
SmallVector<int64_t> partialReductionDims;
500+
for (auto [resultNum, dimExpr] :
501+
llvm::enumerate(partialMap.getResults())) {
502+
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
503+
if (llvm::find(reductionDims, dim) != reductionDims.end()) {
504+
partialReductionDims.push_back(resultNum);
505+
}
506+
}
507+
508+
Value partialResult = partialReduce[idx];
509+
Value init = linalgOp.getDpsInits()[idx];
510+
511+
auto reduction = b.create<linalg::ReduceOp>(
512+
loc, partialResult, init, partialReductionDims,
513+
[&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
486514
// Get the combiner op.
487515
SmallVector<Operation *, 4> combinerOps;
488516
matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
489517
Operation *clonedReductionOp = b.clone(*combinerOps[0]);
490518
// Combine the partial result element (inputs[0]) with the accumulator (inputs[1]).
491-
clonedReductionOp->setOperand(0, inputs[idx]);
492-
clonedReductionOp->setOperand(1, inputs[numInits + idx]);
493-
// Yield.
494-
yieldedValues.push_back(clonedReductionOp->getResult(0));
495-
}
496-
b.create<linalg::YieldOp>(loc, yieldedValues);
497-
});
498-
return MergeResult{
499-
{reduction.getOperation()},
500-
llvm::map_to_vector(reduction->getResults(),
501-
[](OpResult r) -> Value { return r; })};
519+
clonedReductionOp->setOperand(0, inputs[0]);
520+
clonedReductionOp->setOperand(1, inputs[1]);
521+
b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
522+
});
523+
524+
mergeOperations.push_back(reduction);
525+
replacements.push_back(reduction->getResult(0));
526+
}
527+
528+
return MergeResult{mergeOperations, replacements};
529+
}
530+
531+
LogicalResult getPartialResultTilePosition(
532+
Operation *op, OpBuilder &b, unsigned resultNumber,
533+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
534+
SmallVector<OpFoldResult> &resultOffsets,
535+
SmallVector<OpFoldResult> &resultSizes,
536+
ArrayRef<int> reductionDims) const {
537+
auto linalgOp = cast<LinalgOp>(op);
538+
539+
AffineMap partialMap =
540+
getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
541+
for (AffineExpr dimExpr : partialMap.getResults()) {
542+
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
543+
resultSizes.push_back(sizes[dim]);
544+
545+
if (llvm::find(reductionDims, dim) != reductionDims.end()) {
546+
// Reduction dims are reduced, and are always output in the same
547+
// place. So use offset 0 for them.
548+
resultOffsets.push_back(b.getIndexAttr(0));
549+
} else {
550+
resultOffsets.push_back(offsets[dim]);
551+
}
552+
}
553+
554+
return success();
502555
}
503556
};
504557

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -656,21 +656,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
656656
resultOffset, resultSize);
657657
case scf::SCFTilingOptions::ReductionTilingStrategy::
658658
PartialReductionOuterReduction: {
659-
// TODO: This does not work for non identity accesses to the result tile.
660-
// The proper fix is to add a getPartialResultTilePosition method to
661-
// PartialReductionOpInterface.
662-
resultOffset =
663-
SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
664-
for (size_t i = 0; i < offsets.size(); i++) {
665-
resultSize.push_back(
666-
tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
659+
auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
660+
if (!redOp) {
661+
return rewriter.notifyMatchFailure(
662+
op, "PartialReductionOuterReduction tiling strategy is only supported"
663+
"for operations implementing PartialReductionOpInterface");
667664
}
668-
return success();
665+
// Get reduction dimensions.
666+
// TODO: PartialReductionOpInterface should really query TilingInterface
667+
// itself and find reduction dimensions.
668+
SmallVector<int> reductionDims;
669+
for (auto [idx, iteratorType] :
670+
llvm::enumerate(op.getLoopIteratorTypes())) {
671+
if (iteratorType == utils::IteratorType::reduction)
672+
reductionDims.push_back(idx);
673+
}
674+
return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
675+
resultOffset, resultSize,
676+
reductionDims);
677+
}
669678
default:
670679
return rewriter.notifyMatchFailure(op,
671680
"unhandled reduction tiling strategy");
672681
}
673-
}
674682
}
675683

676684
static FailureOr<MergeResult>

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ module attributes {transform.with_named_sequence} {
3232
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
3333
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
3434
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
35-
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
36-
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
35+
// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
3736
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
3837
// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
3938
// CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
@@ -81,13 +80,13 @@ module attributes {transform.with_named_sequence} {
8180
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
8281
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
8382
// CHECK: func @reduction_tile_transpose
84-
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
85-
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
83+
// CHECK: tensor.empty(%{{.*}}) : tensor<?x5xf32>
84+
// CHECK: linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
8685
// CHECK: scf.for
87-
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
86+
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
8887
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
89-
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
90-
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
88+
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
89+
// CHECK: scf.yield {{.*}} : tensor<?x5xf32>
9190
// CHECK: }
9291
// CHECK: linalg.reduce
9392
// CHECK: return
@@ -129,8 +128,7 @@ module attributes {transform.with_named_sequence} {
129128
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
130129
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
131130
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
132-
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
133-
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
131+
// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
134132
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
135133
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
136134
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -183,9 +181,7 @@ module attributes {transform.with_named_sequence} {
183181
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
184182
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
185183
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
186-
// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
187-
// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
188-
// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
184+
// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
189185
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
190186
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
191187
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -243,8 +239,7 @@ module attributes {transform.with_named_sequence} {
243239
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
244240
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
245241
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
246-
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
247-
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
242+
// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
248243
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
249244
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
250245
// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -422,8 +417,8 @@ func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<
422417
module attributes {transform.with_named_sequence} {
423418
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
424419
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
425-
%1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
426-
by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
420+
%1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
421+
by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
427422
transform.yield
428423
}
429424
}
@@ -444,4 +439,44 @@ module attributes {transform.with_named_sequence} {
444439
// CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
445440
// CHECK: linalg.reduce
446441
// CHECK: arith.addf
442+
// CHECK: linalg.reduce
447443
// CHECK: arith.maximumf
444+
445+
// -----
446+
447+
func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
448+
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
449+
affine_map<(d0, d1, d2) -> (d2, d0)>],
450+
iterator_types = ["parallel", "reduction", "parallel"]}
451+
ins(%arg0 : tensor<?x?x?xf32>)
452+
outs(%out : tensor<?x?xf32>) {
453+
^bb0(%arg7: f32, %arg9: f32):
454+
%42 = arith.addf %arg7, %arg9 : f32
455+
linalg.yield %42 : f32
456+
} -> tensor<?x?xf32>
457+
return %red : tensor<?x?xf32>
458+
}
459+
460+
module attributes {transform.with_named_sequence} {
461+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
462+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
463+
%1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
464+
by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
465+
transform.yield
466+
}
467+
}
468+
469+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
470+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
471+
// CHECK: func @reduction_tile_multi_dim_transpose
472+
// CHECK: tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
473+
// CHECK: linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
474+
// CHECK: scf.for
475+
// CHECK: %[[K:.*]] = affine.min
476+
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
477+
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
478+
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
479+
// CHECK: scf.yield {{.*}} : tensor<?x?x5xf32>
480+
// CHECK: }
481+
// CHECK: linalg.reduce
482+
// CHECK: return

0 commit comments

Comments
 (0)