Skip to content

Commit a1eaed7

Browse files
authored
[mlir][gpu] Fix GPU YieldOP format and traits (#78006)
This patch adds an assembly format to `gpu::YieldOp`. It also adds the `ReturnLike` trait to make the op compatible with `RegionBranchOpInterface`.
1 parent 2e0a105 commit a1eaed7

File tree

4 files changed

+17
-1
lines changed

4 files changed

+17
-1
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/OpDefinition.h"
2424
#include "mlir/IR/OpImplementation.h"
2525
#include "mlir/IR/SymbolTable.h"
26+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2627
#include "mlir/Interfaces/FunctionInterfaces.h"
2728
#include "mlir/Interfaces/InferIntRangeInterface.h"
2829
#include "mlir/Interfaces/InferTypeOpInterface.h"

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
2222
include "mlir/IR/CommonTypeConstraints.td"
2323
include "mlir/IR/EnumAttr.td"
2424
include "mlir/IR/SymbolInterfaces.td"
25+
include "mlir/Interfaces/ControlFlowInterfaces.td"
2526
include "mlir/Interfaces/DataLayoutInterfaces.td"
2627
include "mlir/Interfaces/FunctionInterfaces.td"
2728
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -961,7 +962,7 @@ def GPU_TerminatorOp : GPU_Op<"terminator", [HasParent<"LaunchOp">,
961962
let assemblyFormat = "attr-dict";
962963
}
963964

964-
def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
965+
def GPU_YieldOp : GPU_Op<"yield", [Pure, ReturnLike, Terminator]>,
965966
Arguments<(ins Variadic<AnyType>:$values)> {
966967
let summary = "GPU yield operation";
967968
let description = [{
@@ -974,6 +975,8 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
974975
gpu.yield %f0, %f1 : f32, f32
975976
```
976977
}];
978+
979+
let assemblyFormat = "attr-dict ($values^ `:` type($values))?";
977980
}
978981

979982
// These mirror the reduction combining kinds from the vector dialect.

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ add_mlir_dialect_library(MLIRGPUDialect
3737
LINK_LIBS PUBLIC
3838
MLIRArithDialect
3939
MLIRDLTIDialect
40+
MLIRControlFlowInterfaces
4041
MLIRFunctionInterfaces
4142
MLIRInferIntRangeInterface
4243
MLIRIR

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,17 @@ module attributes {gpu.container_module} {
9494
// CHECK-NEXT: } : (f32) -> f32
9595
%sum1 = gpu.all_reduce add %one uniform {} : (f32) -> f32
9696

97+
// CHECK: %{{.*}} = gpu.all_reduce %{{.*}} {
98+
// CHECK-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
99+
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
100+
// CHECK-NEXT: gpu.yield %{{.*}} : f32
101+
// CHECK-NEXT: } : (f32) -> f32
102+
%sum2 = gpu.all_reduce %one {
103+
^bb(%lhs : f32, %rhs : f32):
104+
%tmp = arith.addf %lhs, %rhs : f32
105+
gpu.yield %tmp : f32
106+
} : (f32) -> (f32)
107+
97108
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
98109
%sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
99110

0 commit comments

Comments
 (0)