Skip to content

Commit b537c5b

Browse files
committed
[mlir] Async: clone constants into async.execute functions and parallel compute functions
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D107007
1 parent 84602f9 commit b537c5b

File tree

8 files changed

+132
-2
lines changed

8 files changed

+132
-2
lines changed

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
190190

191191
ModuleOp module = op->getParentOfType<ModuleOp>();
192192

193+
// Make sure that all constants will be inside the parallel operation body to
194+
// reduce the number of parallel compute function arguments.
195+
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
196+
193197
ParallelComputeFunctionType computeFuncType =
194198
getParallelComputeFunctionType(op, rewriter);
195199

mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
235235
MLIRContext *ctx = module.getContext();
236236
Location loc = execute.getLoc();
237237

238+
// Make sure that all constants will be inside the outlined async function to
239+
// reduce the number of function arguments.
240+
cloneConstantsIntoTheRegion(execute.body());
241+
238242
// Collect all outlined function inputs.
239243
SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
240244
execute.dependencies().end());

mlir/lib/Dialect/Async/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
33
AsyncRuntimeRefCounting.cpp
44
AsyncRuntimeRefCountingOpt.cpp
55
AsyncToAsyncRuntime.cpp
6+
PassDetail.cpp
67

78
ADDITIONAL_HEADER_DIRS
89
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===- PassDetail.cpp - Async Pass class details ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/IR/Builders.h"
11+
#include "mlir/Transforms/RegionUtils.h"
12+
13+
using namespace mlir;
14+
15+
void mlir::async::cloneConstantsIntoTheRegion(Region &region) {
16+
OpBuilder builder(&region);
17+
cloneConstantsIntoTheRegion(region, builder);
18+
}
19+
20+
void mlir::async::cloneConstantsIntoTheRegion(Region &region,
21+
OpBuilder &builder) {
22+
// Values implicitly captured by the region.
23+
llvm::SetVector<Value> captures;
24+
getUsedValuesDefinedAbove(region, region, captures);
25+
26+
OpBuilder::InsertionGuard guard(builder);
27+
builder.setInsertionPointToStart(&region.front());
28+
29+
// Clone ConstantLike operations into the region.
30+
for (Value capture : captures) {
31+
Operation *op = capture.getDefiningOp();
32+
if (!op || !op->hasTrait<OpTrait::ConstantLike>())
33+
continue;
34+
35+
Operation *cloned = builder.clone(*op);
36+
37+
for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
38+
Value orig = std::get<0>(tuple);
39+
Value replacement = std::get<1>(tuple);
40+
replaceAllUsesInRegionWith(orig, replacement, region);
41+
}
42+
}
43+
}

mlir/lib/Dialect/Async/Transforms/PassDetail.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ class SCFDialect;
2525
#define GEN_PASS_CLASSES
2626
#include "mlir/Dialect/Async/Passes.h.inc"
2727

28+
// -------------------------------------------------------------------------- //
29+
// Utility functions shared by Async Transformations.
30+
// -------------------------------------------------------------------------- //
31+
32+
// Forward declarations.
33+
class OpBuilder;
34+
35+
namespace async {
36+
37+
/// Clone ConstantLike operations that are defined above the given region and
38+
/// have users in the region into the region entry block. We do that to reduce
39+
/// the number of function arguments when we outline `async.execute` and
40+
/// `scf.parallel` operations body into functions.
41+
void cloneConstantsIntoTheRegion(Region &region);
42+
void cloneConstantsIntoTheRegion(Region &region, OpBuilder &builder);
43+
44+
} // namespace async
45+
2846
} // namespace mlir
2947

3048
#endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_

mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,14 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
8989
}
9090

9191
// Function outlined from the inner async.execute operation.
92-
// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
92+
// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
9393
// CHECK-SAME: -> !llvm.ptr<i8>
9494
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
9595
// CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
9696
// CHECK: call @mlirAsyncRuntimeExecute
9797
// CHECK: llvm.intr.coro.suspend
98-
// CHECK: memref.store %arg0, %arg1[%arg2] : memref<1xf32>
98+
// CHECK: %[[C0:.*]] = constant 0 : index
99+
// CHECK: memref.store %arg0, %arg1[%[[C0]]] : memref<1xf32>
99100
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
100101

101102
// Function outlined from the outer async.execute operation.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: mlir-opt %s \
2+
// RUN: -async-parallel-for=async-dispatch=true \
3+
// RUN: | FileCheck %s
4+
5+
// RUN: mlir-opt %s \
6+
// RUN: -async-parallel-for=async-dispatch=false \
7+
// RUN: -canonicalize -inline -symbol-dce \
8+
// RUN: | FileCheck %s
9+
10+
// Check that constants defined outside of the `scf.parallel` body will be
11+
// sunk into the parallel compute function to avoid blowing up the number
12+
// of parallel compute function arguments.
13+
14+
// CHECK-LABEL: func @clone_constant(
15+
func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
16+
%one = constant 1.0 : f32
17+
18+
scf.parallel (%i) = (%lb) to (%ub) step (%st) {
19+
memref.store %one, %arg0[%i] : memref<?xf32>
20+
}
21+
22+
return
23+
}
24+
25+
// CHECK-LABEL: func private @parallel_compute_fn(
26+
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
27+
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
28+
// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index,
29+
// CHECK-SAME: %[[LB:arg[0-9]+]]: index,
30+
// CHECK-SAME: %[[UB:arg[0-9]+]]: index,
31+
// CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
32+
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32>
33+
// CHECK-SAME: ) {
34+
// CHECK: %[[CST:.*]] = constant 1.0{{.*}} : f32
35+
// CHECK: scf.for
36+
// CHECK: memref.store %[[CST]], %[[MEMREF]]

mlir/test/Dialect/Async/async-to-async-runtime.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,26 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
406406
// Check that structured control flow lowered to CFG.
407407
// CHECK-NOT: scf.if
408408
// CHECK: cond_br %[[FLAG]]
409+
410+
// -----
411+
// Constants captured by the async.execute region should be cloned into the
412+
// outline async execute function.
413+
414+
// CHECK-LABEL: @clone_constants
415+
func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
416+
%c0 = constant 0 : index
417+
%token = async.execute {
418+
memref.store %arg0, %arg1[%c0] : memref<1xf32>
419+
async.yield
420+
}
421+
async.await %token : !async.token
422+
return
423+
}
424+
425+
// Function outlined from the async.execute operation.
426+
// CHECK-LABEL: func private @async_execute_fn(
427+
// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
428+
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
429+
// CHECK-SAME: ) -> !async.token
430+
// CHECK: %[[CST:.*]] = constant 0 : index
431+
// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]

0 commit comments

Comments
 (0)