Skip to content

Commit 8b4a5e7

Browse files
authored
[Transform][Fusion] loosen the isTiledOpInLoop check to cover the no-insert_slice case (#311)
Sometimes, the `insert_slice` may be eliminated because `SRC` and `DEST` have the same size.
1 parent 686c34c commit 8b4a5e7

File tree

2 files changed

+81
-7
lines changed

2 files changed

+81
-7
lines changed

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
529529
return false;
530530
unsigned index = std::distance(uses.begin(), iter);
531531
SmallVector<unsigned> indices =
532-
llvm::to_vector(llvm::seq<unsigned>(0, numberUses));
532+
llvm::to_vector(llvm::seq<unsigned>(numberUses));
533533
indices.push_back(indices[index]);
534534
indices.erase(indices.begin() + index);
535535
operand->get().shuffleUseList(indices);
@@ -636,12 +636,9 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {
636636
return failure();
637637

638638
// 3. check whether has either extract or insert slice op
639-
auto walkResult = forOp->walk(
640-
[](tensor::ExtractSliceOp) { return WalkResult::interrupt(); });
641-
if (!walkResult.wasInterrupted())
642-
return failure();
643-
walkResult = forOp->walk([](OffsetSizeAndStrideOpInterface op) {
644-
return isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(op)
639+
auto walkResult = forOp->walk([](OffsetSizeAndStrideOpInterface op) {
640+
return isa<tensor::ExtractSliceOp, tensor::InsertSliceOp,
641+
tensor::ParallelInsertSliceOp>(op)
645642
? WalkResult::interrupt()
646643
: WalkResult::advance();
647644
});

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,83 @@ module {
9595

9696
// -----
9797

98+
#map = affine_map<(d0) -> (d0 * 32)>
99+
#map1 = affine_map<(d0) -> (d0 * 16)>
100+
module {
101+
/// CHECK-LABEL: @fuse_mlp_vnni
102+
func.func @fuse_mlp_vnni(%arg0: tensor<128x1024xbf16>, %arg1: tensor<1024x512xbf16>, %arg2: tensor<512xbf16>) -> tensor<128x512xbf16> attributes {llvm.emit_c_interface} {
103+
%c2 = arith.constant 2 : index
104+
%c64 = arith.constant 64 : index
105+
%c0 = arith.constant 0 : index
106+
%cst = arith.constant dense<0.000000e+00> : tensor<128x512xbf16>
107+
/// CHECK: tensor.empty
108+
%0 = tensor.empty() : tensor<128x512xbf16>
109+
/// CHECK: tensor.empty
110+
%1 = tensor.empty() : tensor<16x64x16x32xbf16>
111+
%pack = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %1 : tensor<1024x512xbf16> -> tensor<16x64x16x32xbf16>
112+
/// CHECK: tensor.empty
113+
%2 = tensor.empty() : tensor<16x64x8x32x2xbf16>
114+
%pack_0 = tensor.pack %pack inner_dims_pos = [2] inner_tiles = [2] into %2 : tensor<16x64x16x32xbf16> -> tensor<16x64x8x32x2xbf16>
115+
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (16)
116+
%3 = scf.forall (%arg3) in (16) shared_outs(%arg4 = %0) -> (tensor<128x512xbf16>) {
117+
%9 = affine.apply #map(%arg3)
118+
%extracted_slice = tensor.extract_slice %arg4[0, %9] [128, 32] [1, 1] : tensor<128x512xbf16> to tensor<128x32xbf16>
119+
/// CHECK: tensor.empty
120+
%10 = tensor.empty() : tensor<128x32xf32>
121+
/// CHECK: linalg.copy
122+
%11 = linalg.copy ins(%extracted_slice : tensor<128x32xbf16>) outs(%10 : tensor<128x32xf32>) -> tensor<128x32xf32>
123+
/// CHECK: %[[TMP_RESULT:.*]]:2 = scf.for
124+
%12:2 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %11, %arg7 = %extracted_slice) -> (tensor<128x32xf32>, tensor<128x32xbf16>) {
125+
%14 = affine.apply #map1(%arg5)
126+
%extracted_slice_1 = tensor.extract_slice %arg0[0, %14] [128, 32] [1, 1] : tensor<128x1024xbf16> to tensor<128x32xbf16>
127+
/// CHECK: %[[PACK_OUT:.*]] = tensor.pack
128+
/// CHECK: %[[PACK_OUT_VNNI:.*]] = tensor.pack %[[PACK_OUT]]
129+
%extracted_slice_2 = tensor.extract_slice %pack_0[%arg3, %arg5, 0, 0, 0] [1, 2, 8, 32, 2] [1, 1, 1, 1, 1] : tensor<16x64x8x32x2xbf16> to tensor<1x2x8x32x2xbf16>
130+
/// CHECK: %[[COLLAPSE_OUT:.*]] = tensor.collapse_shape %[[PACK_OUT_VNNI]]
131+
%collapsed = tensor.collapse_shape %extracted_slice_2 [[0, 1], [2], [3], [4]] : tensor<1x2x8x32x2xbf16> into tensor<2x8x32x2xbf16>
132+
/// CHECK: %[[EXPAND_OUT:.*]] = tensor.expand_shape
133+
%expanded = tensor.expand_shape %extracted_slice_1 [[0], [1, 2]] output_shape [128, 2, 16] : tensor<128x32xbf16> into tensor<128x2x16xbf16>
134+
%15 = tensor.empty() : tensor<2x128x16xbf16>
135+
/// CHECK: %[[TRANSPOSE_OUT:.*]] = linalg.transpose ins(%[[EXPAND_OUT]] :
136+
%transposed = linalg.transpose ins(%expanded : tensor<128x2x16xbf16>) outs(%15 : tensor<2x128x16xbf16>) permutation = [1, 0, 2]
137+
/// CHECK: %[[MATMUL_OUT:.*]] = linalgx.batch_reduce_matmul_vnni ins(%[[TRANSPOSE_OUT]], %[[COLLAPSE_OUT]] :
138+
%16 = linalgx.batch_reduce_matmul_vnni ins(%transposed, %collapsed : tensor<2x128x16xbf16>, tensor<2x8x32x2xbf16>) outs(%arg6 : tensor<128x32xf32>) -> tensor<128x32xf32>
139+
%17 = arith.addi %arg5, %c2 : index
140+
%18 = arith.cmpi sge, %17, %c64 : index
141+
/// CHECK: %[[IF_RESULT:.*]] = scf.if
142+
%19 = scf.if %18 -> (tensor<128x32xbf16>) {
143+
%20 = linalg.copy ins(%16 : tensor<128x32xf32>) outs(%arg7 : tensor<128x32xbf16>) -> tensor<128x32xbf16>
144+
scf.yield %20 : tensor<128x32xbf16>
145+
} else {
146+
scf.yield %arg7 : tensor<128x32xbf16>
147+
}
148+
/// CHECK: scf.yield %[[MATMUL_OUT]], %[[IF_RESULT]] :
149+
scf.yield %16, %19 : tensor<128x32xf32>, tensor<128x32xbf16>
150+
}
151+
/// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast
152+
/// CHECK: %[[ADD_OUT:.*]] = linalg.add ins(%[[TMP_RESULT]]#1, %[[BROADCAST_OUT]] :
153+
/// CHECK: %[[MAX_OUT:.*]] = linalg.max ins(%[[ADD_OUT]],
154+
%13 = affine.apply #map(%arg3)
155+
scf.forall.in_parallel {
156+
/// CHECK: tensor.parallel_insert_slice %[[MAX_OUT]]
157+
/// CHECK: tensor.parallel_insert_slice
158+
/// CHECK: tensor.parallel_insert_slice
159+
tensor.parallel_insert_slice %12#1 into %arg4[0, %13] [128, 32] [1, 1] : tensor<128x32xbf16> into tensor<128x512xbf16>
160+
}
161+
}
162+
%4 = tensor.empty() : tensor<128x512xbf16>
163+
%broadcasted = linalg.broadcast ins(%arg2 : tensor<512xbf16>) outs(%4 : tensor<128x512xbf16>) dimensions = [0]
164+
%5 = tensor.empty() : tensor<128x512xbf16>
165+
%6 = linalg.add ins(%3, %broadcasted : tensor<128x512xbf16>, tensor<128x512xbf16>) outs(%5 : tensor<128x512xbf16>) -> tensor<128x512xbf16>
166+
%7 = tensor.empty() : tensor<128x512xbf16>
167+
%8 = linalg.max ins(%6, %cst : tensor<128x512xbf16>, tensor<128x512xbf16>) outs(%7 : tensor<128x512xbf16>) -> tensor<128x512xbf16>
168+
/// CHECK: return %[[FINAL_RESULT]]#2
169+
return %8 : tensor<128x512xbf16>
170+
}
171+
}
172+
173+
// -----
174+
98175
#map = affine_map<(d0) -> (d0 * 128)>
99176
module {
100177
/// CHECK-LABEL: @fuse_multiple_consumer

0 commit comments

Comments
 (0)