Skip to content

Commit 6b331ce

Browse files
committed
[mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface, and implements consumer fusion for FuseIntoContainingOp. - Add interface method 'getIterationDomainTileFromOperandTile' to the tiling interface, which computes the iteration domain tile position from an operand tile position. - Add interface method 'getTiledImplementationFromOperandTile' to the tiling interface, which generates the tiled implementation from an operand tile position. - Implement the above two methods for the Linalg tiling interface and use them to support consumer fusion in FuseIntoContainingOp.
1 parent 282b56f commit 6b331ce

File tree

8 files changed

+549
-40
lines changed

8 files changed

+549
-40
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
6363
The method returns the operation that is the tiled
6464
implementation.
6565
}],
66-
/*retType=*/"FailureOr<TilingResult>",
66+
/*retType=*/"FailureOr<::mlir::TilingResult>",
6767
/*methodName=*/"getTiledImplementation",
6868
/*args=*/(ins
6969
"OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
8282
by the tiled implementation. Expects the same `offsets` and `sizes` as
8383
used to obtain the tiled implementation of the operation.
8484
}],
85-
/*retType=*/"LogicalResult",
85+
/*retType=*/"::mlir::LogicalResult",
8686
/*methodName=*/"getResultTilePosition",
8787
/*args=*/(ins
8888
"OpBuilder &":$b,
8989
"unsigned":$resultNumber,
9090
"ArrayRef<OpFoldResult> ":$offsets,
9191
"ArrayRef<OpFoldResult> ":$sizes,
92-
"SmallVector<OpFoldResult> &":$resultOffsets,
93-
"SmallVector<OpFoldResult> &":$resultSizes),
92+
"SmallVectorImpl<OpFoldResult> &":$resultOffsets,
93+
"SmallVectorImpl<OpFoldResult> &":$resultSizes),
94+
/*methodBody=*/"",
95+
/*defaultImplementation=*/[{
96+
return failure();
97+
}]
98+
>,
99+
InterfaceMethod<
100+
/*desc=*/[{
101+
Method to return the position of iteration domain tile computed by the
102+
tiled operation.
103+
}],
104+
/*retType=*/"::mlir::LogicalResult",
105+
/*methodName=*/"getIterationDomainTileFromOperandTile",
106+
/*args=*/(ins
107+
"OpBuilder &":$b,
108+
"unsigned":$operandNumber,
109+
"ArrayRef<OpFoldResult> ":$offsets,
110+
"ArrayRef<OpFoldResult> ":$sizes,
111+
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
112+
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
94113
/*methodBody=*/"",
95114
/*defaultImplementation=*/[{
96115
return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
119138
iteration space).
120139
- `sizes` provides the size of the tile.
121140
}],
122-
/*retType=*/"FailureOr<TilingResult>",
141+
/*retType=*/"FailureOr<::mlir::TilingResult>",
123142
/*methodName=*/"generateResultTileValue",
124143
/*args=*/(ins
125144
"OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
131150
return failure();
132151
}]
133152
>,
153+
InterfaceMethod<
154+
/*desc=*/[{
155+
Method to generate the tiled implementation of an operation from
156+
operand tile position.
157+
158+
Generates the IR that computes the tiled implementation of an
159+
operation from operand tile. The `offsets` and `sizes`
160+
describe the tile of the operand required. This is different from
161+
`getTiledImplementation` which generates the tiled
162+
implementation of the operation given a tile of the
163+
iteration space. This method generates a tiled
164+
implementation of the operation based on the tile of the
165+
operand required. This method enables consumer fusion by using
166+
tile and fuse. The method returns failure if the operation
167+
can't be tiled to generate the operand tile. In practical terms
168+
this implies it cannot be tiled and fused with its producers.
169+
170+
- `offsets` provides the offset of the tile in the coordinate system
171+
of the original iteration space, i.e., if an iteration space
172+
dimension had non-zero offset, it must be included in the offset
173+
provided here (as opposed to zero-based offset "relative" to the
174+
iteration space).
175+
- `sizes` provides the size of the tile.
176+
}],
177+
/*retType=*/"FailureOr<::mlir::TilingResult>",
178+
/*methodName=*/"getTiledImplementationFromOperandTile",
179+
/*args=*/(ins
180+
"OpBuilder &":$b,
181+
"unsigned":$operandNumber,
182+
"ArrayRef<OpFoldResult>":$offsets,
183+
"ArrayRef<OpFoldResult>":$sizes),
184+
/*methodBody=*/"",
185+
/*defaultImplementation=*/[{
186+
return failure();
187+
}]
188+
>,
134189
InterfaceMethod<
135190
/*desc=*/[{
136191
Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
142197
transformations are done, this method can be used to lower to scalar
143198
code that can then be lowered to LLVM or SPIR-V dialects.
144199
}],
145-
/*retType=*/"LogicalResult",
200+
/*retType=*/"::mlir::LogicalResult",
146201
/*methodName=*/"generateScalarImplementation",
147202
/*args=*/(ins
148203
"OpBuilder &":$b,

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
24252425

24262426
LogicalResult SoftmaxOp::getResultTilePosition(
24272427
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2428-
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2429-
SmallVector<OpFoldResult> &resultSizes) {
2428+
ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
2429+
SmallVectorImpl<OpFoldResult> &resultSizes) {
24302430
if (resultNumber == 0) {
24312431
resultOffsets.assign(offsets.begin(), offsets.end());
24322432
resultSizes.assign(sizes.begin(), sizes.end());

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110110
}));
111111
}
112112

113-
// Instantiate the tiled implementation of the operation.
113+
/// Instantiate the tiled implementation of the operation.
114114
FailureOr<TilingResult>
115115
getTiledImplementation(Operation *op, OpBuilder &b,
116116
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135-
// Return the details of the output tile generated by the tiled
136-
// implementation.
135+
void
136+
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137+
ArrayRef<OpFoldResult> offsets,
138+
ArrayRef<OpFoldResult> sizes,
139+
SmallVectorImpl<OpFoldResult> &mappedOffsets,
140+
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141+
unsigned numLoops = linalgOp.getNumLoops();
142+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
143+
mappedOffsets.resize(numLoops);
144+
mappedSizes.resize(numLoops);
145+
if (!indexingMap.isPermutation()) {
146+
SmallVector<Range> iterationDomain =
147+
tilingInterfaceOp.getIterationDomain(b);
148+
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
149+
mappedOffsets[index] = value.offset;
150+
mappedSizes[index] = value.size;
151+
}
152+
}
153+
for (const auto &&[index, value] :
154+
llvm::enumerate(indexingMap.getResults())) {
155+
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
156+
mappedOffsets[dimPosition] = offsets[index];
157+
mappedSizes[dimPosition] = sizes[index];
158+
}
159+
}
160+
161+
/// Return the details of the output tile generated by the tiled
162+
/// implementation.
163+
LogicalResult getIterationDomainTileFromOperandTile(
164+
Operation *op, OpBuilder &b, unsigned operandNumber,
165+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166+
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
167+
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
168+
auto linalgOp = cast<LinalgOp>(op);
169+
170+
// Check that the indexing map used for the operand is a projected
171+
// permutation. This could be relaxed with a more general approach that can
172+
// map the offsets and sizes from the operand to iteration space tiles
173+
// (filling in full extent for dimensions not used to access the result).
174+
AffineMap indexingMap =
175+
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
176+
if (!indexingMap.isProjectedPermutation()) {
177+
return emitError(op->getLoc(),
178+
"unhandled get iter domain position when operand is not "
179+
"accessed using a permuted projection");
180+
}
181+
182+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
183+
iterDomainOffsets, iterDomainSizes);
184+
return success();
185+
}
186+
187+
/// Return the details of the output tile generated by the tiled
188+
/// implementation.
137189
LogicalResult
138190
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
139191
ArrayRef<OpFoldResult> offsets,
140192
ArrayRef<OpFoldResult> sizes,
141-
SmallVector<OpFoldResult> &resultOffsets,
142-
SmallVector<OpFoldResult> &resultSizes) const {
193+
SmallVectorImpl<OpFoldResult> &resultOffsets,
194+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
143195
Location loc = op->getLoc();
144196
LinalgOp linalgOp = cast<LinalgOp>(op);
145197

@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
160212
return success();
161213
}
162214

215+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
216+
Operation *op, OpBuilder &b, unsigned operandNumber,
217+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219+
auto tilingInterfaceOp = cast<TilingInterface>(op);
220+
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
221+
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222+
return emitError(
223+
op->getLoc(),
224+
"unable to obtain the iter domain position of the operation.");
225+
}
226+
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
227+
mappedSizes);
228+
}
229+
163230
FailureOr<TilingResult>
164231
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165232
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
177244
"unhandled tiled implementation generation when result is not "
178245
"accessed using a permuted projection");
179246
}
180-
181-
auto numLoops = linalgOp.getNumLoops();
247+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
249+
mappedOffsets, mappedSizes);
182250
auto tilingInterfaceOp = cast<TilingInterface>(op);
183-
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184-
iterationTileSizes(numLoops);
185-
if (!indexingMap.isPermutation()) {
186-
SmallVector<Range> iterationDomain =
187-
tilingInterfaceOp.getIterationDomain(b);
188-
for (const auto &range : llvm::enumerate(iterationDomain)) {
189-
iterationTileOffsets[range.index()] = range.value().offset;
190-
iterationTileSizes[range.index()] = range.value().size;
191-
}
192-
}
193-
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194-
unsigned dimPosition =
195-
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198-
}
199-
200251
FailureOr<TilingResult> tilingResult =
201-
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202-
iterationTileSizes);
252+
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
253+
254+
if (failed(tilingResult))
255+
return failure();
256+
203257
if (tilingResult->tiledOps.size() != 1)
204258
return op->emitOpError("failed to generate tiled implementation");
205259

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
6161
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
6262
ArrayRef<OpFoldResult> offsets,
6363
ArrayRef<OpFoldResult> sizes,
64-
SmallVector<OpFoldResult> &resultOffsets,
65-
SmallVector<OpFoldResult> &resultSizes) const {
64+
SmallVectorImpl<OpFoldResult> &resultOffsets,
65+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
6666
resultOffsets.assign(offsets.begin(), offsets.end());
6767
resultSizes.assign(sizes.begin(), sizes.end());
6868
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
199199
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
200200
ArrayRef<OpFoldResult> offsets,
201201
ArrayRef<OpFoldResult> sizes,
202-
SmallVector<OpFoldResult> &resultOffsets,
203-
SmallVector<OpFoldResult> &resultSizes) const {
202+
SmallVectorImpl<OpFoldResult> &resultOffsets,
203+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
204204
// The iteration domain is over outer dimensions of packed layout. In this
205205
// context, the outer dimensions of `resultOffsets` are `offsets`. The
206206
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
452452
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
453453
ArrayRef<OpFoldResult> offsets,
454454
ArrayRef<OpFoldResult> sizes,
455-
SmallVector<OpFoldResult> &resultOffsets,
456-
SmallVector<OpFoldResult> &resultSizes) const {
455+
SmallVectorImpl<OpFoldResult> &resultOffsets,
456+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
457457
resultOffsets = llvm::to_vector(offsets);
458458
resultSizes = llvm::to_vector(sizes);
459459
return success();
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// RUN: mlir-opt %s -split-input-file -test-linalg-fuse-consumer | FileCheck %s
2+
3+
#map = affine_map<()[s0] -> (64 ceildiv s0)>
4+
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
5+
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
6+
// CHECK-LABEL: func.func @fuse_tileable_consumer
7+
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
8+
// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
9+
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
10+
func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
11+
// CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
12+
%0 = tensor.empty(%arg0) : tensor<?xf32>
13+
%1 = affine.apply #map()[%arg0]
14+
// CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
15+
%2 = tensor.empty() : tensor<64xf32>
16+
// CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
17+
%3 = tensor.empty() : tensor<64xf32>
18+
// CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
19+
%4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
20+
%6 = affine.apply #map1(%arg3)[%arg0]
21+
%7 = affine.min #map2(%arg3)[%arg0]
22+
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
23+
%extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
24+
// CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
25+
%8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
26+
27+
// CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
28+
// CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
29+
// CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
30+
31+
scf.forall.in_parallel {
32+
// CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
33+
// CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
34+
tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
35+
}
36+
} {"containing"}
37+
// CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
38+
%5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>, "consumer"} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
39+
// CHECK: return %[[RES]]#1
40+
return %5 : tensor<64xf32>
41+
}
42+
// -----
43+
44+
#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
45+
#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
46+
#map2 = affine_map<(d0) -> (d0 * 50)>
47+
#map3 = affine_map<(d0) -> (d0 * 16)>
48+
#map4 = affine_map<(d0, d1) -> (d0, d1)>
49+
#map5 = affine_map<(d0, d1) -> (d1, d0)>
50+
// CHECK-LABEL: func.func @fuse_consumer_multi_output
51+
// CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
52+
// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
53+
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
54+
func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
55+
%cst = arith.constant 0.000000e+00 : f32
56+
// CHECK: %[[INIT:.*]] = linalg.fill
57+
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
58+
// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
59+
%1 = tensor.empty() : tensor<123x789xf32>
60+
// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
61+
%2 = tensor.empty() : tensor<789x123xf32>
62+
// CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
63+
%3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
64+
%5 = affine.min #map(%arg3)
65+
%6 = affine.min #map1(%arg4)
66+
%7 = affine.apply #map2(%arg3)
67+
%8 = affine.apply #map3(%arg4)
68+
%9 = affine.apply #map2(%arg3)
69+
%10 = affine.apply #map3(%arg4)
70+
// CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
71+
%extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
72+
// CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
73+
%extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
74+
// CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
75+
%extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
76+
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
77+
%11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
78+
79+
// CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
80+
// CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
81+
// CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
82+
83+
%12 = affine.apply #map2(%arg3)
84+
%13 = affine.apply #map3(%arg4)
85+
scf.forall.in_parallel {
86+
// CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[[LOOP_ARG1]]
87+
// CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#1 into %[[LOOP_ARG2]]
88+
// CHECK: tensor.parallel_insert_slice %[[MATMUL_RES]] into %[[LOOP_ARG0]]
89+
tensor.parallel_insert_slice %11 into %arg5[%12, %13] [%5, %6] [1, 1] : tensor<?x?xf32> into tensor<123x789xf32>
90+
}
91+
} {"containing"}
92+
// CHECK: %[[ORI_OUTPUT:.*]]:2 = linalg.generic
93+
%4:2 = linalg.generic {"consumer", indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<123x789xf32>) outs(%1, %2 : tensor<123x789xf32>, tensor<789x123xf32>) {
94+
^bb0(%in: f32, %out: f32, %out_0: f32):
95+
%5 = arith.addf %in, %out : f32
96+
%6 = arith.addf %5, %out_0 : f32
97+
linalg.yield %5, %6 : f32, f32
98+
} -> (tensor<123x789xf32>, tensor<789x123xf32>)
99+
// CHECK: return %[[RES]]#1, %[[RES]]#2
100+
return %4#0, %4#1 : tensor<123x789xf32>, tensor<789x123xf32>
101+
}
102+
103+

0 commit comments

Comments
 (0)