Skip to content

Commit f7dd9d8

Browse files
committed
extend fuse producer to multi-level extractSliceOp
1 parent 6e45fa9 commit f7dd9d8

File tree

5 files changed

+303
-5
lines changed

5 files changed

+303
-5
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,15 @@ struct SCFFuseProducerOfSliceResult {
157157
Value tiledAndFusedProducer; // Tile and fused producer value.
158158
SmallVector<Operation *> tiledOps;
159159
};
160+
160161
/// Fuses the producer of a single slice, `candidateSliceOp`, by computing the
/// slice of the producer in-place, given the loop nest `loops`. An empty
/// optional signals that no result is available.
std::optional<SCFFuseProducerOfSliceResult>
161162
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
162163
tensor::ExtractSliceOp candidateSliceOp,
163164
MutableArrayRef<LoopLikeOpInterface> loops);
164165

166+
/// Overload that takes only the candidate slice operation and handles
/// multi-level candidates: starting from `candidateSliceOp` it looks through
/// loop iter_args and chained `tensor.extract_slice` operations to find the
/// untiled producer, then fuses it level by level from the outermost slice to
/// the innermost one. Returns `std::nullopt` if `candidateSliceOp` is not a
/// `tensor.extract_slice` or if fusion fails at any level.
std::optional<SCFFuseProducerOfSliceResult>
167+
tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
168+
165169
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
166170
/// on the slice of the producer computed in place it is possible that within
167171
/// the loop nest same slice of the producer is computed multiple times. It is

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 144 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,12 +1068,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
10681068
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
10691069
}
10701070

1071-
/// Implementation of fusing producer of a single slice by computing the
1071+
/// Basic implementation of fusing producer of a single slice by computing the
10721072
/// slice of the producer in-place.
1073-
std::optional<scf::SCFFuseProducerOfSliceResult>
1074-
mlir::scf::tileAndFuseProducerOfSlice(
1075-
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1076-
MutableArrayRef<LoopLikeOpInterface> loops) {
1073+
static std::optional<scf::SCFFuseProducerOfSliceResult>
1074+
tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
1075+
tensor::ExtractSliceOp candidateSliceOp,
1076+
MutableArrayRef<LoopLikeOpInterface> loops) {
10771077
// 1. Get the producer of the source (potentially walking through
10781078
// `iter_args` of nested `scf.for`)
10791079
auto [fusableProducer, destinationInitArg] =
@@ -1185,6 +1185,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
11851185
tileAndFuseResult->tiledOps};
11861186
}
11871187

1188+
/// Get the real producer from candidate ExtractSliceOp
1189+
///
1190+
/// ```
1191+
/// %0 = producer
1192+
/// %1 = scf.for(%arg1 = %0)
1193+
/// %2 = extract %arg1
1194+
/// %3 = scf.for(%arg2 = %2)
1195+
/// %4 = extract %args2
1196+
/// ...
1197+
/// ```
1198+
///
1199+
/// @param candidateSliceOp: %4 = extract %args2
1200+
/// @param backwardSlice: in-out parameter populated by backward extractSliceOps
1201+
/// @return OpResult Producer : %0 = producer
1202+
static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
1203+
Operation *candidateSliceOp,
1204+
SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
1205+
int maxDepth = 5) {
1206+
if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
1207+
return failure();
1208+
// control recursive time in avoid of stack overflow
1209+
if (curDepth > maxDepth)
1210+
return failure();
1211+
1212+
auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
1213+
backwardSlice.push_back(extractOp);
1214+
Value rootSource = extractOp.getSourceMutable().get();
1215+
1216+
while (true) {
1217+
if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
1218+
if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
1219+
iterArg.getOwner()->getParentOp())) {
1220+
rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
1221+
continue;
1222+
}
1223+
return failure();
1224+
} else if (auto sliceOp =
1225+
rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
1226+
// walk up loop to find larger candidate extractSliceOp
1227+
return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
1228+
curDepth + 1);
1229+
}
1230+
break;
1231+
}
1232+
return dyn_cast<OpResult>(rootSource);
1233+
}
1234+
1235+
/// Recursively find the outer nest loops of given loop(included) while the
1236+
/// predict function succeed, sorted from outer to inner.
1237+
///
1238+
/// @param loop: target loop, note that this loop will be also included. I.e.
1239+
/// if no other nest loops were found, just return itself.
1240+
/// @param pred: predict function, the termination condition of recursive
1241+
/// process.
1242+
/// @return Outer Nest Loops: nest loops outside given target loop(included).
1243+
///
1244+
/// E.g.
1245+
///
1246+
/// ```
1247+
/// %0 = scf.for()
1248+
/// %1 = scf.for()
1249+
/// %2 = scf.for()
1250+
/// ```
1251+
///
1252+
/// If `%2 = scf.for` is given without specific prediction function, this
1253+
/// function will return three nest loops: %0 + %1 + %2.
1254+
static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1255+
LoopLikeOpInterface loop,
1256+
const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1257+
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1258+
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
1259+
while (outerLoop && succeeded(pred(outerLoop))) {
1260+
nestLoops.push_back(outerLoop);
1261+
outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
1262+
}
1263+
// sorted from outer to inner
1264+
return {nestLoops.rbegin(), nestLoops.rend()};
1265+
}
1266+
1267+
/// Enhanced version for basic implementation of fusing producer, which can deal
1268+
/// with multi-level candidates. E.g.
1269+
///
1270+
/// ```
1271+
/// %0 = untiled_producer
1272+
/// %1 = scf.for(%arg1 = %0)
1273+
/// %2 = tensor.extract_slice %arg1
1274+
/// %3 = scf.for(%arg2 = %2)
1275+
/// %4 = tensor.extract_slice %args2
1276+
/// %5 = tiled_consumer ins(%4)
1277+
/// ```
1278+
///
1279+
/// This utility can fuse untiled producer at `%4 = tensor.extract_slice` within
1280+
/// inner loop `%3 = scf.for`.
1281+
std::optional<scf::SCFFuseProducerOfSliceResult>
1282+
mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
1283+
Operation *candidateSliceOp) {
1284+
SmallVector<tensor::ExtractSliceOp> backwardSlice;
1285+
if (failed(
1286+
getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
1287+
return std::nullopt;
1288+
}
1289+
1290+
std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
1291+
// reverse from outer to inner
1292+
std::reverse(backwardSlice.begin(), backwardSlice.end());
1293+
// multiple application of `tileAndFuseProducerOfSliceImpl`
1294+
for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
1295+
// get nest loops between next candidate sliceOp and tiled producer.
1296+
auto whileProducerOutOfLoopBlock =
1297+
[&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1298+
if (fuseProducerResult) {
1299+
Block &body = loop->getRegion(0).front();
1300+
if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
1301+
->getBlock() == &body)
1302+
return failure();
1303+
}
1304+
return success();
1305+
};
1306+
SmallVector<LoopLikeOpInterface> outerLoops =
1307+
getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),
1308+
whileProducerOutOfLoopBlock);
1309+
fuseProducerResult =
1310+
tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
1311+
if (!fuseProducerResult) {
1312+
return std::nullopt;
1313+
}
1314+
}
1315+
return fuseProducerResult;
1316+
}
1317+
1318+
/// Implementation of fusing producer of a single slice by computing the
1319+
/// slice of the producer in-place.
1320+
std::optional<scf::SCFFuseProducerOfSliceResult>
1321+
mlir::scf::tileAndFuseProducerOfSlice(
1322+
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1323+
MutableArrayRef<LoopLikeOpInterface> loops) {
// Public entry point with the pre-existing signature: it forwards its
// arguments unchanged to the file-local `tileAndFuseProducerOfSliceImpl`.
1324+
return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops);
1325+
}
1326+
11881327
/// Reconstruct the fused producer from within the tiled-and-fused code.
11891328
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
11901329
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
2+
3+
#map = affine_map<(d0) -> (d0 * 128)>
4+
// Payload IR: a `linalg.fill` of a 256x256 tensor feeds the `shared_outs` of
// an `scf.forall`; its 128x128 slice is carried through two nested `scf.for`
// loops as an iter_arg and sliced again to 64x64 before being used as the
// `outs` operand of `linalg.matmul`. The fill is therefore reachable from the
// innermost slice only through two levels of `tensor.extract_slice`.
module {
5+
func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
6+
%c0 = arith.constant 0 : index
7+
%c64 = arith.constant 64 : index
8+
%c128 = arith.constant 128 : index
9+
%cst = arith.constant 0.000000e+00 : f32
10+
%dest0 = tensor.empty() : tensor<256x256xf32>
11+
// Untiled producer that the test expects to be fused into the innermost loop.
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
12+
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
13+
%iv0 = affine.apply #map(%arg3)
14+
%iv1 = affine.apply #map(%arg4)
15+
// First-level slice: 128x128 tile of the filled tensor (via the shared_outs argument).
%extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
16+
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
17+
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
18+
%2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
19+
%3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
20+
// Second-level slice: 64x64 tile of the iter_arg; this is the fusion candidate.
%extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
21+
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
22+
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
23+
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
24+
%insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
25+
scf.yield %insert_slice : tensor<128x128xf32>
26+
}
27+
scf.yield %3 : tensor<128x128xf32>
28+
}
29+
scf.forall.in_parallel {
30+
tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
31+
}
32+
}
33+
return %1 : tensor<256x256xf32>
34+
}
35+
}
36+
37+
// Transform script driving the test: it picks the slice feeding the matmul
// `outs` operand and asks `transform.test.fuse_producer` to fuse its producer.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
      : (!transform.any_op) -> !transform.any_op
    // Operand #2 of the matmul is its `outs` operand, which is produced by the
    // innermost `tensor.extract_slice`.
    %candidate_slice = transform.get_producer_of_operand %matmul[2]
      : (!transform.any_op) -> !transform.any_op
    %producer, %fused_producer = transform.test.fuse_producer %candidate_slice
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
48+
49+
// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
50+
// CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
51+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
52+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
53+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
54+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
55+
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
56+
// CHECK: %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
57+
// CHECK-SAME: shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
58+
// CHECK-SAME: {
59+
// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
60+
// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
61+
// CHECK: %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
62+
// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
63+
// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
64+
// CHECK: %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
65+
// CHECK-SAME: iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
66+
// CHECK-SAME: {
67+
// CHECK: %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
68+
// CHECK-SAME: iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
69+
// CHECK-SAME: {
70+
// CHECK: %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
71+
// CHECK: %[[TILED_FILL_OUT:.*]] = linalg.fill
72+
// CHECK-SAME: outs(%[[FILL_OUT_SLICE1]] :
73+
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
74+
// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
75+
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
76+
// CHECK-SAME: outs(%[[TILED_FILL_OUT]] :
77+
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
78+
// CHECK: scf.yield %[[INSERT_MAT]] :
79+
// CHECK: }
80+
// CHECK: scf.yield %[[LOOP_RESULT2]] :
81+
// CHECK: }
82+
// CHECK: scf.forall.in_parallel {
83+
// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
84+
// CHECK: }
85+
// CHECK: }
86+
// CHECK: return %[[FORALL_RESULT]] :

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
160160
: DiagnosedSilenceableFailure::success();
161161
}
162162

163+
//===----------------------------------------------------------------------===//
164+
// TestFuseProducerOp
165+
//===----------------------------------------------------------------------===//
166+
167+
/// Apply fusing of producer transformation to all payload ops and store both
168+
/// the original producer operation as well as the fused producer operation.
169+
template <typename Range>
170+
static LogicalResult
171+
applyFuseProducer(RewriterBase &rewriter, Operation *transformOp,
172+
Range &&payloadOps, TransformResults &transformResults) {
173+
SmallVector<Operation *> originalProducerOps;
174+
SmallVector<Operation *> fusedProducerOps;
175+
176+
for (Operation *target : payloadOps) {
177+
rewriter.setInsertionPoint(target);
178+
179+
std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResults =
180+
scf::tileAndFuseProducerOfSlice(rewriter, target);
181+
182+
if (!fuseProducerResults)
183+
return failure();
184+
185+
// Report back the relevant handles to the transform op.
186+
originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner());
187+
fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]);
188+
}
189+
190+
transformResults.set(transformOp->getOpResult(0), originalProducerOps);
191+
transformResults.set(transformOp->getOpResult(1), fusedProducerOps);
192+
return success();
193+
}
194+
195+
/// Runs `applyFuseProducer` on the payload ops of the target handle and maps
/// its outcome to a transform result: a definite failure when the helper
/// fails, success otherwise.
DiagnosedSilenceableFailure
transform::TestFuseProducerOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  if (failed(applyFuseProducer(rewriter, getOperation(),
                               state.getPayloadOps(getTarget()),
                               transformResults)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
205+
206+
/// Appends this op's effects to `effects` via the transform-dialect helpers:
/// the target operand is registered with `consumesHandle`, all op results with
/// `producesHandle`, and `modifiesPayload` is recorded.
void transform::TestFuseProducerOp::getEffects(
207+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
208+
consumesHandle(getTargetMutable(), effects);
209+
producesHandle(getOperation()->getOpResults(), effects);
210+
modifiesPayload(effects);
211+
}
212+
163213
//===----------------------------------------------------------------------===//
164214
// TestFuseConsumerOp
165215
//===----------------------------------------------------------------------===//

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
4949
}];
5050
}
5151

52+
// Test-only transform op exercising `scf::tileAndFuseProducerOfSlice`.
def TestFuseProducerOp : Op<Transform_Dialect, "test.fuse_producer",
       [DeclareOpInterfaceMethods<TransformOpInterface>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Fuses the producer of each operation pointed to by the `target` handle
    into the loop nest containing that operation. The op takes no options.

    Returns two handles: `producer` points to the original (untiled)
    producer operations and `fused_producer` to the fused producer
    operations.
  }];

  let arguments =
    (ins TransformHandleTypeInterface:$target);
  let results = (outs TransformHandleTypeInterface:$producer,
                      TransformHandleTypeInterface:$fused_producer);

  let assemblyFormat = [{
    $target attr-dict `:` functional-type(operands, results)
  }];
}
70+
5271
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
5372
[DeclareOpInterfaceMethods<TransformOpInterface>,
5473
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

0 commit comments

Comments
 (0)