Skip to content

Commit f3eb18b

Browse files
committed
support reduce ops and multiple consumers in fusion
1 parent 639a90b commit f3eb18b

File tree

2 files changed

+92
-5
lines changed

2 files changed

+92
-5
lines changed

lib/gc/Transforms/AnyTilableFusion.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,16 @@ verifyTilableOpTileSizesOnAffineMap(RewriterBase &rewriter, Operation *op,
8888
unsigned iterPosition =
8989
cast<AffineDimExpr>(resultExpr.value()).getPosition();
9090
if (iterTypes[iterPosition] == utils::IteratorType::reduction) {
91-
if (iterDomain[iterPosition].size != tileSizes[resultExpr.index()])
91+
std::optional<int64_t> cstIterDomain =
92+
getConstantIntValue(iterDomain[iterPosition].size);
93+
FailureOr<int64_t> cstTileSizes =
94+
ValueBoundsConstraintSet::computeConstantBound(
95+
presburger::BoundType::UB, tileSizes[resultExpr.index()], nullptr,
96+
true);
97+
if (!cstIterDomain || failed(cstTileSizes) ||
98+
cstIterDomain != cstTileSizes) {
9299
return failure();
100+
}
93101
}
94102
}
95103
return success();
@@ -436,7 +444,6 @@ static SmallVector<Operation *> postOpFuseConsumerOfOpResult(
436444
if (failed(consAnchorList))
437445
return tiledConsumerList;
438446

439-
// TODO: sorted by userList and position in parentBlock
440447
for (auto &consAnchor : *consAnchorList) {
441448
if (alreadyTiledOps.count(consAnchor.getFusableOp()))
442449
continue;
@@ -450,7 +457,7 @@ static SmallVector<Operation *> postOpFuseConsumerOfOpResult(
450457
scfX::tileAndFuseConsumerOfSlice(rewriter, *candidateSliceOp);
451458
if (fusedResult) {
452459
tiledConsumerList.push_back(fusedResult.value().tiledOps[0]);
453-
rewriter.eraseOp(consAnchor.getFusableOp());
460+
rewriter.eraseOp(fusedResult.value().origConsumerOperand->getOwner());
454461
}
455462
}
456463

test/gc/Transform/any-tilable-fusion.mlir

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: gc-opt --split-input-file -any-tilable-fusion %s
22

3-
func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> {
3+
module {
4+
func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> {
45
%c32 = arith.constant 32 : index
56
%c512 = arith.constant 512 : index
67
%c128 = arith.constant 128 : index
@@ -58,4 +59,83 @@ func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg
5859
%3 = linalg.add ins(%2, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
5960
%4 = linalg.exp ins(%3 : tensor<128x256xbf16>) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
6061
return %4 : tensor<128x256xbf16>
61-
}
62+
}
63+
}
64+
65+
// -----
66+
67+
#map = affine_map<(d0) -> (d0 * 128)>
68+
module {
69+
func.func @fuse_multiple_consumer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>, %arg3: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
70+
%c0 = arith.constant 0 : index
71+
%c64 = arith.constant 64 : index
72+
%c128 = arith.constant 128 : index
73+
%cst = arith.constant 0.000000e+00 : f32
74+
%dest0 = tensor.empty() : tensor<256x256xf32>
75+
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
76+
%1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest1) -> tensor<256x256xf32> {
77+
%iv0 = affine.apply #map(%arg4)
78+
%iv1 = affine.apply #map(%arg5)
79+
%extracted_slice_1 = tensor.extract_slice %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
80+
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
81+
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
82+
%2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %extracted_slice_1) -> (tensor<128x128xf32>) {
83+
%3 = scf.for %arg9 = %c0 to %c128 step %c64 iter_args(%arg10 = %arg8) -> (tensor<128x128xf32>) {
84+
%extracted_slice_4 = tensor.extract_slice %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
85+
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
86+
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg9] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
87+
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
88+
%insert_slice = tensor.insert_slice %4 into %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
89+
scf.yield %insert_slice : tensor<128x128xf32>
90+
}
91+
scf.yield %3 : tensor<128x128xf32>
92+
}
93+
scf.forall.in_parallel {
94+
tensor.parallel_insert_slice %2 into %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
95+
}
96+
}
97+
%5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
98+
%6 = linalg.add ins(%1, %arg3 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
99+
return %5, %6 : tensor<256x256xf32>, tensor<256x256xf32>
100+
}
101+
}
102+
103+
// -----
104+
105+
#map = affine_map<(d0) -> (d0 * 128)>
106+
module {
107+
func.func @fuse_reduce(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256xf32> {
108+
%c0 = arith.constant 0 : index
109+
%c64 = arith.constant 64 : index
110+
%c128 = arith.constant 128 : index
111+
%c256 = arith.constant 256 : index
112+
%cst = arith.constant 0.000000e+00 : f32
113+
%dest0 = tensor.empty() : tensor<256x256xf32>
114+
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
115+
%1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
116+
%iv0 = affine.apply #map(%arg3)
117+
%iv1 = affine.apply #map(%arg4)
118+
%extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 256] [1, 1] : tensor<256x256xf32> to tensor<128x256xf32>
119+
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
120+
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 256] [1, 1] : tensor<512x256xf32> to tensor<512x256xf32>
121+
%2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x256xf32>) {
122+
%3 = scf.for %arg8 = %c0 to %c256 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x256xf32>) {
123+
%extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x256xf32> to tensor<64x64xf32>
124+
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
125+
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
126+
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
127+
%insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x256xf32>
128+
scf.yield %insert_slice : tensor<128x256xf32>
129+
}
130+
scf.yield %3 : tensor<128x256xf32>
131+
}
132+
scf.forall.in_parallel {
133+
tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x256xf32>
134+
}
135+
}
136+
%5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
137+
%dest2 = tensor.empty() : tensor<256xf32>
138+
%6 = linalg.reduce { arith.addf } ins(%5 : tensor<256x256xf32>) outs(%dest2 : tensor<256xf32>) dimensions = [1]
139+
return %6 : tensor<256xf32>
140+
}
141+
}

0 commit comments

Comments
 (0)