Skip to content

Commit 03e820f

Browse files
committed
[mlir][gpu] Patterns to promote gpu.shuffle to specialized AMDGPU ops
Only swizzle promotion for now, may add DPP ops support later.
1 parent e268f71 commit 03e820f

File tree

7 files changed

+124
-17
lines changed

7 files changed

+124
-17
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,24 +132,24 @@ def MapNestedForallToThreads :
132132
TransformEachOpTrait,
133133
TransformOpInterface]> {
134134
let description = [{
135-
Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
135+
Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
136136
distributed `gpu.thread_id` attribute.
137137

138138
The operation searches for `scf.forall` ops nested under `target` and maps
139-
each such op to GPU threads.
140-
139+
each such op to GPU threads.
140+
141141
`scf.forall` induction variables are rewritten to `gpu.thread_id` according
142142
to the `mapping` attribute.
143143

144144
Different types of mappings attributes are supported:
145145
- the block_dims is a list of integers that specifies the number of
146146
threads in each dimension. This is a mandatory attribute that is used
147-
to constrain the number of threads in each dimension. If an
147+
to constrain the number of threads in each dimension. If an
148148
`scf.forall` op is mapped to fewer threads, predication occurs.
149149
- the warp_dims is a list of integers that specifies the number of
150150
warps in each dimension. This is an optional attribute that is used
151151
to constrain the number of warps in each dimension. When present, this
152-
attribute must be specified in a way that is compatible with the
152+
attribute must be specified in a way that is compatible with the
153153
block_dims attribute. If an `scf.forall` op is mapped to fewer warps,
154154
predication occurs.
155155

@@ -164,7 +164,7 @@ def MapNestedForallToThreads :
164164
inserted after each scf.forall op. At this time, this is an all or nothing
165165
choice. This will need to be tightened in the future.
166166

167-
The operation alters the block size of the given gpu_launch using the
167+
The operation alters the block size of the given gpu_launch using the
168168
mandatory block_dims argument.
169169

170170
#### Return modes:
@@ -268,7 +268,7 @@ def MapForallToBlocks :
268268
Only scf.forall distributed to **at most 3 dimensions** are
269269
currently supported.
270270

271-
The operation alters the block size of the given gpu_launch using the
271+
The operation alters the block size of the given gpu_launch using the
272272
grid_dims argument.
273273

274274
#### Return modes:
@@ -300,7 +300,7 @@ def MapForallToBlocks :
300300
`:` functional-type($target, $result)
301301
}];
302302
let hasVerifier = 1;
303-
303+
304304
let extraClassDeclaration = [{
305305
::mlir::DiagnosedSilenceableFailure applyToOne(
306306
::mlir::transform::TransformRewriter &rewriter,
@@ -310,4 +310,15 @@ def MapForallToBlocks :
310310
}];
311311
}
312312

313+
def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
314+
"apply_patterns.gpu.gpu_shuffle_to_amdgpu",
315+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
316+
let description = [{
317+
Collects patterns that are tryin to promote `gpu.shuffle`s to specialized
318+
AMDGPU intrinsics.
319+
}];
320+
let assemblyFormat = "attr-dict";
321+
}
322+
323+
313324
#endif // GPU_TRANSFORM_OPS

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
109109
/// Erase barriers that do not enforce conflicting memory side effects.
110110
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
111111

112+
/// Tries to promote `gpu.shuffle`s to specialized AMDGPU intrinsics.
113+
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
114+
112115
/// Generate the code for registering passes.
113116
#define GEN_PASS_REGISTRATION
114117
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
150150
rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
151151
Value dstLane;
152152
// TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
153-
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
154-
// perf.
155153
switch (op.getMode()) {
156154
case gpu::ShuffleMode::DOWN:
157155
dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ add_mlir_dialect_library(MLIRGPUTransforms
3737
Transforms/ModuleToBinary.cpp
3838
Transforms/NVVMAttachTarget.cpp
3939
Transforms/ParallelLoopMapper.cpp
40+
Transforms/PromoteShuffleToAMDGPU.cpp
4041
Transforms/ROCDLAttachTarget.cpp
41-
Transforms/ShuffleRewriter.cpp
4242
Transforms/SPIRVAttachTarget.cpp
43+
Transforms/ShuffleRewriter.cpp
4344
Transforms/SubgroupReduceLowering.cpp
4445

4546
OBJECT
@@ -52,8 +53,8 @@ add_mlir_dialect_library(MLIRGPUTransforms
5253
MLIRParallelLoopMapperEnumsGen
5354

5455
LINK_LIBS PUBLIC
55-
MLIRAffineUtils
5656
MLIRAMDGPUDialect
57+
MLIRAffineUtils
5758
MLIRArithDialect
5859
MLIRAsyncDialect
5960
MLIRBufferizationDialect
@@ -67,12 +68,12 @@ add_mlir_dialect_library(MLIRGPUTransforms
6768
MLIRMemRefDialect
6869
MLIRNVVMTarget
6970
MLIRPass
71+
MLIRROCDLDialect
72+
MLIRROCDLTarget
7073
MLIRSCFDialect
71-
MLIRSideEffectInterfaces
7274
MLIRSPIRVTarget
75+
MLIRSideEffectInterfaces
7376
MLIRSupport
74-
MLIRROCDLDialect
75-
MLIRROCDLTarget
7677
MLIRTransformUtils
7778
MLIRVectorDialect
7879
)

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1415
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -136,6 +137,11 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
136137
populateGpuRewritePatterns(patterns);
137138
}
138139

140+
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
141+
RewritePatternSet &patterns) {
142+
populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
143+
}
144+
139145
//===----------------------------------------------------------------------===//
140146
// ApplyUnrollVectorsSubgroupMmaOp
141147
//===----------------------------------------------------------------------===//
@@ -914,9 +920,10 @@ class GPUTransformDialectExtension
914920
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
915921

916922
GPUTransformDialectExtension() {
917-
declareGeneratedDialect<scf::SCFDialect>();
918-
declareGeneratedDialect<arith::ArithDialect>();
919923
declareGeneratedDialect<GPUDialect>();
924+
declareGeneratedDialect<amdgpu::AMDGPUDialect>();
925+
declareGeneratedDialect<arith::ArithDialect>();
926+
declareGeneratedDialect<scf::SCFDialect>();
920927
registerTransformOps<
921928
#define GET_OP_LIST
922929
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains patterns to try to promote `gpu.shuffle`s to specialized
10+
// AMDGPU intrinsics.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
25+
/// and offset must be a constant integer in the range [0, 31].
26+
struct PromoteShuffleToSwizzlePattern
27+
: public OpRewritePattern<gpu::ShuffleOp> {
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
LogicalResult matchAndRewrite(gpu::ShuffleOp op,
31+
PatternRewriter &rewriter) const override {
32+
if (op.getMode() != gpu::ShuffleMode::XOR)
33+
return rewriter.notifyMatchFailure(op,
34+
"only xor shuffle mode is supported");
35+
36+
if (!isConstantIntValue(op.getWidth(), 64))
37+
return rewriter.notifyMatchFailure(op,
38+
"only 64 width shuffle is supported");
39+
40+
std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
41+
if (!offset)
42+
return rewriter.notifyMatchFailure(op,
43+
"offset must be a constant integer");
44+
45+
int64_t offsetValue = *offset;
46+
if (offsetValue < 0 || offsetValue >= 32)
47+
return rewriter.notifyMatchFailure(op,
48+
"offset must be in the range [0, 31]");
49+
50+
Location loc = op.getLoc();
51+
Value res = rewriter.create<amdgpu::SwizzleBitModeOp>(
52+
loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31,
53+
/*orMask=*/0, /*xorMask=*/offsetValue);
54+
Value valid = rewriter.create<arith::ConstantIntOp>(loc, 1, /*width*/ 1);
55+
rewriter.replaceOp(op, {res, valid});
56+
return success();
57+
}
58+
};
59+
} // namespace
60+
61+
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
62+
RewritePatternSet &patterns) {
63+
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
64+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
2+
3+
module attributes {transform.with_named_sequence} {
4+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
5+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
6+
transform.apply_patterns to %func {
7+
transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
8+
} : !transform.any_op
9+
transform.yield
10+
}
11+
}
12+
13+
// CHECK-LABEL: func @gpu_shuffle_swizzle
14+
// CHECK-SAME: (%[[ARG:.*]]: i32)
15+
func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
16+
// CHECK: %[[TRUE:.*]] = arith.constant true
17+
// CHECK: %[[RES:.*]] = amdgpu.swizzle_bitmode %[[ARG]] 31 0 23 : i32
18+
// CHECK: return %[[RES]], %[[TRUE]] : i32, i1
19+
%width = arith.constant 64 : i32
20+
%offset = arith.constant 23 : i32
21+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
22+
func.return %shfl, %pred : i32, i1
23+
}

0 commit comments

Comments
 (0)