@@ -47,6 +47,15 @@ static bool hasSideEffects(Operation *op) {
 struct GpuAsyncRegionPass::ThreadTokenCallback {
   ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

+  WalkResult operator()(Block *block) {
+    for (Operation &op : make_early_inc_range(*block)) {
+      if (failed(visit(&op)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  }
+
+private:
   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
   // create a current token (unless it already exists), and 'thread' that token
   // through the `op` so that it executes asynchronously.
@@ -55,11 +64,15 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
   // host-synchronize execution. A `!gpu.async.token` will therefore only be
   // used inside of its block and GPU execution will always synchronize with
   // the host at block boundaries.
-  WalkResult operator()(Operation *op) {
+  LogicalResult visit(Operation *op) {
     if (isa<gpu::LaunchOp>(op))
       return op->emitOpError("replace with gpu.launch_func first");
-    if (isa<gpu::WaitOp>(op))
-      return op->emitOpError("unexpected pre-existing gpu.wait");
+    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
+      if (currentToken)
+        waitOp.addAsyncDependency(currentToken);
+      currentToken = waitOp.asyncToken();
+      return success();
+    }
     builder.setInsertionPoint(op);
     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
@@ -71,14 +84,9 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
     return success();
   }

-private:
   // Replaces asyncOp with a clone that returns a token.
   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
     auto *op = asyncOp.getOperation();
-    if (asyncOp.getAsyncToken())
-      // TODO: Support ops that are already async.
-      return op->emitOpError("is already async");
-
     auto tokenType = builder.getType<gpu::AsyncTokenType>();

     // If there is no current token, insert a `gpu.wait async` without
@@ -87,6 +95,11 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
       currentToken = createWaitOp(op->getLoc(), tokenType, {});
     asyncOp.addAsyncDependency(currentToken);

+    // Return early if op returns a token already.
+    currentToken = asyncOp.getAsyncToken();
+    if (currentToken)
+      return success();
+
     // Clone the op to return a token in addition to the other results.
     SmallVector<Type, 1> resultTypes;
     resultTypes.reserve(1 + op->getNumResults());
@@ -315,10 +328,7 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
 // execution semantics and that no GPU ops are asynchronous yet.
 void GpuAsyncRegionPass::runOnFunction() {
-  if (getFunction()
-          .getRegion()
-          .walk(ThreadTokenCallback(getContext()))
-          .wasInterrupted())
+  if (getFunction()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
     return signalPassFailure();

   // Collect gpu.wait ops that we can move out of async.execute regions.
0 commit comments