Skip to content

Commit e7d7250

Browse files
authored
[GEN] Implement GPU to GEN lowering (#13427)
Implements the lowering of the GPU dialect to the GEN dialect where possible. Currently there are only 6 GEN operations, so the lowering is a bit limited. Signed-off-by: Finlay Marno [email protected] --------- Signed-off-by: Finlay Marno <[email protected]>
1 parent 25f6fcb commit e7d7250

File tree

7 files changed

+349
-0
lines changed

7 files changed

+349
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- GPUToGEN.h - GPU to GEN Passes ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides passes to convert GPU dialect to GEN dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_GPUTOGEN_GPUTOGEN_H
#define MLIR_CONVERSION_GPUTOGEN_GPUTOGEN_H

#include <memory>

namespace mlir {

class Pass;
class RewritePatternSet;

// Generated pass declaration for ConvertGpuOpsToGENOps (see Passes.td).
#define GEN_PASS_DECL_CONVERTGPUOPSTOGENOPS
#include "mlir/Conversion/Passes.h.inc"

/// Appends to `patterns` the conversions from GPU dialect ops to their GEN
/// dialect equivalents (index queries, barrier, and shuffle).
void populateGPUToGENPatterns(RewritePatternSet &patterns);

} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOGEN_GPUTOGEN_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "mlir/Conversion/GENToLLVM/GENToLLVM.h"
3737
#include "mlir/Conversion/GENToSPIRV/GENToSPIRV.h"
3838
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
39+
#include "mlir/Conversion/GPUToGEN/GPUToGEN.h"
3940
#include "mlir/Conversion/GPUToGENX/GPUToGENXPass.h"
4041
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
4142
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,15 @@ def LowerHostCodeToLLVMPass : Pass<"lower-host-to-llvm", "ModuleOp"> {
540540
let dependentDialects = ["LLVM::LLVMDialect"];
541541
}
542542

543+
//===----------------------------------------------------------------------===//
// GPUToGEN
//===----------------------------------------------------------------------===//

def ConvertGpuOpsToGENOps : Pass<"convert-gpu-to-gen"> {
  let summary = "Generate GEN operations for gpu operations";
  // The lowering patterns materialize arith.constant ops (dimension operands
  // and the shuffle validity bit), so the Arith dialect must be declared as a
  // dependent dialect in addition to GEN; otherwise creating those constants
  // can fail when the input module contains no arith ops.
  let dependentDialects = ["arith::ArithDialect", "GEN::GENDialect"];
}
551+
543552
//===----------------------------------------------------------------------===//
544553
// GPUToGENX
545554
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_subdirectory(FuncToSPIRV)
2525
add_subdirectory(GENToLLVM)
2626
add_subdirectory(GENToSPIRV)
2727
add_subdirectory(GPUCommon)
28+
add_subdirectory(GPUToGEN)
2829
add_subdirectory(GPUToGENX)
2930
add_subdirectory(GPUToNVVM)
3031
add_subdirectory(GPUToROCDL)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_conversion_library(MLIRGPUToGEN
  GPUToGEN.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GPUToGEN

  DEPENDS
  MLIRConversionPassIncGen

  LINK_COMPONENTS
  Core

  # GPUToGEN.cpp creates arith constants and uses the pass / dialect-conversion
  # infrastructure, so those libraries must be linked in addition to the GPU
  # and GEN dialect libraries.
  LINK_LIBS PUBLIC
  MLIRArithDialect
  MLIRGPUDialect
  MLIRGENDialect
  MLIRPass
  MLIRTransformUtils
  )
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
//===- GPUToGEN.cpp - GPU to GEN Patterns ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert GPU dialect to GEN dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/GPUToGEN/GPUToGEN.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/GEN/IR/GENDialect.h"
17+
#include "mlir/Dialect/GEN/IR/GENOps.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19+
#include "mlir/IR/MLIRContext.h"
20+
#include "mlir/IR/Matchers.h"
21+
#include "mlir/Pass/Pass.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
24+
#include "llvm/Support/Debug.h"
25+
26+
namespace mlir {
27+
#define GEN_PASS_DEF_CONVERTGPUOPSTOGENOPS
28+
#include "mlir/Conversion/Passes.h.inc"
29+
} // namespace mlir
30+
31+
using namespace mlir;
32+
33+
template <typename GPUOp, typename GENOp>
34+
class GPUIndexOpToGENLowering : public OpConversionPattern<GPUOp> {
35+
public:
36+
using OpConversionPattern<GPUOp>::OpConversionPattern;
37+
using OpAdaptor = typename GPUOp::Adaptor;
38+
39+
LogicalResult
40+
matchAndRewrite(GPUOp op, OpAdaptor adaptor,
41+
ConversionPatternRewriter &rewriter) const final {
42+
auto dim = static_cast<std::uint32_t>(adaptor.getDimension());
43+
Value idxDim = rewriter.create<arith::ConstantIntOp>(op->getLoc(), dim, 32);
44+
rewriter.replaceOpWithNewOp<GENOp>(op, rewriter.getIndexType(), idxDim);
45+
return success();
46+
}
47+
};
48+
49+
/// Lowers gpu.barrier to gen.barrier.
class GPUBarrierToGENLowering : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern<gpu::BarrierOp>::OpConversionPattern;
  using OpAdaptor = typename gpu::BarrierOp::Adaptor;

  // The original split match()/rewrite() form is deprecated in the conversion
  // framework; since this conversion can never fail, the two are folded into
  // the single matchAndRewrite entry point.
  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<GEN::BarrierOp>(op);
    return success();
  }
};
61+
62+
class GPUShuffleToGENLowering : public OpConversionPattern<gpu::ShuffleOp> {
63+
public:
64+
using OpConversionPattern<gpu::ShuffleOp>::OpConversionPattern;
65+
using OpAdaptor = typename gpu::ShuffleOp::Adaptor;
66+
67+
LogicalResult
68+
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
69+
ConversionPatternRewriter &rewriter) const final {
70+
71+
auto gpuMode = adaptor.getMode();
72+
const auto genMode = [](gpu::ShuffleMode mode) {
73+
switch (mode) {
74+
case gpu::ShuffleMode::XOR:
75+
return GEN::ShflKind::XOR;
76+
case gpu::ShuffleMode::DOWN:
77+
return GEN::ShflKind::DOWN;
78+
case gpu::ShuffleMode::UP:
79+
return GEN::ShflKind::UP;
80+
case gpu::ShuffleMode::IDX:
81+
return GEN::ShflKind::IDX;
82+
}
83+
llvm_unreachable("expected a matching shuffle mode");
84+
}(gpuMode);
85+
86+
// TODO unable to validate gpu width parameter, potential for producing
87+
// invalid code
88+
IntegerAttr widthAttr;
89+
if (!matchPattern(adaptor.getWidth(), m_Constant(&widthAttr))) {
90+
return rewriter.notifyMatchFailure(
91+
op, "shuffle width must be a constant value");
92+
}
93+
94+
Value trueValue = rewriter.create<arith::ConstantOp>(
95+
op->getLoc(), rewriter.getBoolAttr(true));
96+
auto result = rewriter.create<GEN::SubGroupShuffleOp>(
97+
op->getLoc(), op->getResult(0).getType(), adaptor.getValue(),
98+
adaptor.getOffset(), genMode);
99+
100+
rewriter.replaceOp(op, {result, trueValue});
101+
return success();
102+
}
103+
};
104+
105+
/// Registers all GPU-to-GEN conversion patterns on `patterns`.
void mlir::populateGPUToGENPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // Index queries: thread/block ids and sizes.
  patterns.add<GPUIndexOpToGENLowering<gpu::ThreadIdOp, GEN::LocalIdOp>,
               GPUIndexOpToGENLowering<gpu::BlockIdOp, GEN::WorkGroupIdOp>,
               GPUIndexOpToGENLowering<gpu::BlockDimOp, GEN::WorkGroupSizeOp>,
               GPUIndexOpToGENLowering<gpu::GridDimOp, GEN::NumWorkGroupsOp>>(
      ctx);
  // Synchronization and sub-group data exchange.
  patterns.add<GPUBarrierToGENLowering, GPUShuffleToGENLowering>(ctx);
}
113+
114+
namespace {
115+
struct ConvertGpuOpsToGENOpsPass
116+
: public impl::ConvertGpuOpsToGENOpsBase<ConvertGpuOpsToGENOpsPass> {
117+
void runOnOperation() override {
118+
ConversionTarget target(getContext());
119+
120+
target.addLegalOp<arith::ConstantOp>();
121+
target.addLegalDialect<GEN::GENDialect>();
122+
// The ops of gpu dialect that can currently be mapped to GEN
123+
target.addIllegalOp<gpu::ThreadIdOp, gpu::BlockIdOp, gpu::BlockDimOp,
124+
gpu::GridDimOp, gpu::BarrierOp, gpu::ShuffleOp>();
125+
126+
mlir::RewritePatternSet patterns(&getContext());
127+
populateGPUToGENPatterns(patterns);
128+
129+
if (failed(applyPartialConversion(getOperation(), target,
130+
std::move(patterns))))
131+
signalPassFailure();
132+
}
133+
};
134+
} // namespace
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// RUN: mlir-opt -split-input-file -convert-gpu-to-gen %s | FileCheck %s
2+
3+
// gpu.thread_id lowers to gen.local_id with the dimension as an i32 constant.
gpu.module @local_id_kernels {
  // CHECK-LABEL: gen_local_id_x
  gpu.func @gen_local_id_x() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 0 : i32
    // CHECK: gen.local_id [[DIM]]
    %0 = gpu.thread_id x
    gpu.return
  }

  // CHECK-LABEL: gen_local_id_y
  gpu.func @gen_local_id_y() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 1 : i32
    // CHECK: gen.local_id [[DIM]]
    %0 = gpu.thread_id y
    gpu.return
  }

  // CHECK-LABEL: gen_local_id_z
  gpu.func @gen_local_id_z() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 2 : i32
    // CHECK: gen.local_id [[DIM]]
    %0 = gpu.thread_id z
    gpu.return
  }
}
28+
29+
// -----
30+
31+
32+
// gpu.block_id lowers to gen.work_group_id with the dimension as an i32 constant.
gpu.module @work_group_id_kernels {
  // CHECK-LABEL: gen_work_group_id_x
  gpu.func @gen_work_group_id_x() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 0 : i32
    // CHECK: gen.work_group_id [[DIM]]
    %0 = gpu.block_id x
    gpu.return
  }

  // CHECK-LABEL: gen_work_group_id_y
  gpu.func @gen_work_group_id_y() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 1 : i32
    // CHECK: gen.work_group_id [[DIM]]
    %0 = gpu.block_id y
    gpu.return
  }

  // CHECK-LABEL: gen_work_group_id_z
  gpu.func @gen_work_group_id_z() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 2 : i32
    // CHECK: gen.work_group_id [[DIM]]
    %0 = gpu.block_id z
    gpu.return
  }
}
57+
58+
// -----
59+
60+
61+
// gpu.block_dim lowers to gen.work_group_size with the dimension as an i32 constant.
gpu.module @work_group_size_kernels {
  // CHECK-LABEL: gen_work_group_size_x
  gpu.func @gen_work_group_size_x() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 0 : i32
    // CHECK: gen.work_group_size [[DIM]]
    %0 = gpu.block_dim x
    gpu.return
  }

  // CHECK-LABEL: gen_work_group_size_y
  gpu.func @gen_work_group_size_y() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 1 : i32
    // CHECK: gen.work_group_size [[DIM]]
    %0 = gpu.block_dim y
    gpu.return
  }

  // CHECK-LABEL: gen_work_group_size_z
  gpu.func @gen_work_group_size_z() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 2 : i32
    // CHECK: gen.work_group_size [[DIM]]
    %0 = gpu.block_dim z
    gpu.return
  }
}
86+
87+
// -----
88+
89+
90+
// gpu.grid_dim lowers to gen.num_work_groups with the dimension as an i32 constant.
gpu.module @num_work_groups_kernels {
  // CHECK-LABEL: gen_num_work_groups_x
  gpu.func @gen_num_work_groups_x() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 0 : i32
    // CHECK: gen.num_work_groups [[DIM]]
    %0 = gpu.grid_dim x
    gpu.return
  }

  // CHECK-LABEL: gen_num_work_groups_y
  gpu.func @gen_num_work_groups_y() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 1 : i32
    // CHECK: gen.num_work_groups [[DIM]]
    %0 = gpu.grid_dim y
    gpu.return
  }

  // CHECK-LABEL: gen_num_work_groups_z
  gpu.func @gen_num_work_groups_z() kernel {
    // CHECK: [[DIM:%.*]] = arith.constant 2 : i32
    // CHECK: gen.num_work_groups [[DIM]]
    %0 = gpu.grid_dim z
    gpu.return
  }
}
115+
116+
// -----
117+
118+
// gpu.barrier lowers directly to gen.barrier.
gpu.module @barrier_kernels {
  // CHECK-LABEL: gen_barrier
  gpu.func @gen_barrier() kernel {
    // CHECK: gen.barrier
    gpu.barrier
    gpu.return
  }
}
126+
127+
// -----
128+
129+
// gpu.shuffle (with constant width) lowers to gen.sub_group_shuffle.
// Fixed: the original directive read "// CHECK-LABEL gpu.module ..." without
// the colon, so FileCheck silently ignored it.
// CHECK-LABEL: gpu.module @shuffle_kernels
gpu.module @shuffle_kernels {
  // CHECK: gpu.func @gen_shuffle_xor(%[[IN_XOR:.*]]: f32, %[[OFFSET_XOR:.*]]: i32) kernel {
  gpu.func @gen_shuffle_xor(%in : f32, %offset: i32) kernel {
    // CHECK: %{{.*}} = gen.sub_group_shuffle xor %[[IN_XOR]], %[[OFFSET_XOR]] : f32
    %width = arith.constant 32 : i32
    %0, %1 = gpu.shuffle xor %in, %offset, %width : f32
    gpu.return
  }
  // CHECK: gpu.func @gen_shuffle_up(%[[IN_UP:.*]]: f32, %[[OFFSET_UP:.*]]: i32) kernel {
  gpu.func @gen_shuffle_up(%in : f32, %offset: i32) kernel {
    // CHECK: %{{.*}} = gen.sub_group_shuffle up %[[IN_UP]], %[[OFFSET_UP]] : f32
    %width = arith.constant 32 : i32
    %0, %1 = gpu.shuffle up %in, %offset, %width : f32
    gpu.return
  }
  // CHECK: gpu.func @gen_shuffle_down(%[[IN_DOWN:.*]]: f32, %[[OFFSET_DOWN:.*]]: i32) kernel {
  gpu.func @gen_shuffle_down(%in : f32, %offset: i32) kernel {
    // CHECK: %{{.*}} = gen.sub_group_shuffle down %[[IN_DOWN]], %[[OFFSET_DOWN]] : f32
    %width = arith.constant 32 : i32
    %0, %1 = gpu.shuffle down %in, %offset, %width : f32
    gpu.return
  }
  // CHECK: gpu.func @gen_shuffle_idx(%[[IN_IDX:.*]]: f32, %[[OFFSET_IDX:.*]]: i32) kernel {
  gpu.func @gen_shuffle_idx(%in : f32, %offset: i32) kernel {
    // CHECK: %{{.*}} = gen.sub_group_shuffle idx %[[IN_IDX]], %[[OFFSET_IDX]] : f32
    %width = arith.constant 32 : i32
    %0, %1 = gpu.shuffle idx %in, %offset, %width : f32
    gpu.return
  }
}

0 commit comments

Comments
 (0)