Skip to content

Commit 888717e

Browse files
[mlir][transform] Enable gpu-to-nvvm via conversion patterns driven by TD
This revision untangles a few more conversion pieces and allows rewriting the relatively intricate (and somewhat inconsistent) LowerGpuOpsToNVVMOpsPass in a declarative fashion that provides a much better understanding and control. Differential Revision: https://reviews.llvm.org/D157617
1 parent 5bf8de8 commit 888717e

File tree

12 files changed

+336
-81
lines changed

12 files changed

+336
-81
lines changed

mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ struct LogicalResult;
3030
class ModuleOp;
3131
class Operation;
3232
class RewritePatternSet;
33+
class TypeConverter;
3334

3435
class Pass;
3536

3637
namespace gpu {
38+
enum class AddressSpace : uint32_t;
3739
class GPUModuleOp;
3840
} // namespace gpu
3941

@@ -69,6 +71,13 @@ void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
6971
StringRef gpuBinaryAnnotation = {},
7072
bool kernelBarePtrCallConv = false);
7173

74+
/// A function that maps a gpu::AddressSpace enum value to a target-specific
/// integer value.
75+
using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;
76+
77+
/// Populates memory space attribute conversion rules for lowering
78+
/// gpu.address_space to integer values.
79+
void populateGpuMemorySpaceAttributeConversions(
80+
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
7281
} // namespace mlir
7382

7483
#endif // MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,61 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
1414
include "mlir/Interfaces/SideEffectInterfaces.td"
1515
include "mlir/IR/OpBase.td"
1616

17+
//===----------------------------------------------------------------------===//
18+
// Apply...ConversionPatternsOp
19+
//===----------------------------------------------------------------------===//
20+
21+
def ApplyGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,
22+
"apply_conversion_patterns.gpu.gpu_to_nvvm",
23+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
24+
["verifyTypeConverter"]>]> {
25+
let description = [{
26+
Collects patterns that convert GPU dialect ops to NVVM dialect ops. These
27+
patterns require an "LLVMTypeConverter".
28+
}];
29+
let assemblyFormat = "attr-dict";
30+
}
31+
32+
def ApplyGPUWwmaToNVVMConversionPatternsOp : Op<Transform_Dialect,
33+
"apply_conversion_patterns.gpu.gpu_wmma_to_nvvm",
34+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
35+
["verifyTypeConverter"]>]> {
36+
let description = [{
37+
Collects patterns that convert GPU dialect ops related to wmma ops
38+
to NVVM dialect ops.
39+
These patterns require an "LLVMTypeConverter".
40+
}];
41+
let assemblyFormat = "attr-dict";
42+
}
43+
44+
def ApplyGPUSubgroupReduceToNVVMConversionPatternsOp : Op<Transform_Dialect,
45+
"apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm",
46+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
47+
["verifyTypeConverter"]>]> {
48+
let description = [{
49+
Collects patterns that convert GPU dialect subgroup reduce ops
50+
to NVVM dialect ops.
51+
These patterns require an "LLVMTypeConverter".
52+
}];
53+
let assemblyFormat = "attr-dict";
54+
}
55+
56+
//===----------------------------------------------------------------------===//
57+
// Apply...PatternsOp
58+
//===----------------------------------------------------------------------===//
59+
60+
def ApplyGPURewritePatternsOp : Op<Transform_Dialect,
61+
"apply_patterns.gpu.gpu_rewrite_patterns",
62+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
63+
let description = [{
64+
Collects GPU rewrite patterns comprising:
65+
1. GpuAllReduceRewrite patterns
66+
2. GpuGlobalIdRewriter patterns
67+
3. GpuShuffleRewriter patterns
68+
}];
69+
let assemblyFormat = "attr-dict";
70+
}
71+
1772
def ApplyUnrollVectorsSubgroupMmaOp : Op<Transform_Dialect,
1873
"apply_patterns.gpu.unroll_vectors_subgroup_mma",
1974
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
1616
include "mlir/Interfaces/SideEffectInterfaces.td"
1717

1818
//===----------------------------------------------------------------------===//
19-
// ApplyNVGPUToNVVMConversionPatternsOp
19+
// Apply...ConversionPatternsOp
2020
//===----------------------------------------------------------------------===//
2121

2222
def ApplyNVGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "GPUOpsLowering.h"
10+
11+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1012
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1113
#include "mlir/IR/Attributes.h"
1214
#include "mlir/IR/Builders.h"

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,6 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
111111
*this->getTypeConverter());
112112
}
113113
};
114-
115-
/// A function that maps a MemorySpace enum to a target-specific integer value.
116-
using MemorySpaceMapping =
117-
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
118-
119-
/// Populates memory space attribute conversion rules for lowering
120-
/// gpu.address_space to integer values.
121-
void populateGpuMemorySpaceAttributeConversions(
122-
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
123114
} // namespace mlir
124115

125116
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1717
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1818
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
19+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1920
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2021
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2122
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@
1717
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1818
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1919
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
20+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
2021
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
2122
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2223
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2324
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
2425
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2526
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2627
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
27-
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2828
#include "mlir/Dialect/Func/IR/FuncOps.h"
2929
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3030
#include "mlir/Dialect/GPU/Transforms/Passes.h"
3131
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3232
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
3333
#include "mlir/Dialect/Math/IR/Math.h"
34+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
3435
#include "mlir/Dialect/Vector/IR/VectorOps.h"
3536
#include "mlir/IR/BuiltinAttributes.h"
3637
#include "mlir/Pass/Pass.h"

mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,8 @@ add_mlir_dialect_library(MLIRGPUTransformOps
2020
MLIRTransformDialect
2121
MLIRVectorDialect
2222
MLIRVectorTransforms
23+
24+
# ConversionPatterns
25+
MLIRNVGPUToNVVM
26+
MLIRGPUToNVVMTransforms
2327
)

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88

99
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
1010

11+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12+
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
13+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1114
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1215
#include "mlir/Dialect/Arith/IR/Arith.h"
1316
#include "mlir/Dialect/Func/IR/FuncOps.h"
1417
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1518
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
19+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
20+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1621
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1722
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
1823
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -29,6 +34,7 @@
2934
#include "mlir/IR/OpDefinition.h"
3035
#include "mlir/IR/Visitors.h"
3136
#include "mlir/Support/LLVM.h"
37+
#include "mlir/Transforms/DialectConversion.h"
3238
#include "llvm/ADT/STLExtras.h"
3339
#include "llvm/ADT/SmallVector.h"
3440
#include "llvm/ADT/TypeSwitch.h"
@@ -47,6 +53,85 @@ using namespace mlir::transform::gpu;
4753
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
4854
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
4955

56+
//===----------------------------------------------------------------------===//
57+
// Apply...ConversionPatternsOp
58+
//===----------------------------------------------------------------------===//
59+
60+
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
61+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
62+
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
63+
// NVVM uses alloca in the default address space to represent private
64+
// memory allocations, so drop private annotations. NVVM uses address
65+
// space 3 for shared memory. NVVM uses the default address space to
66+
// represent global memory.
67+
// Used in populateGpuToNVVMConversionPatterns so attaching here for now.
68+
// TODO: We should have a single to_nvvm_type_converter.
69+
populateGpuMemorySpaceAttributeConversions(
70+
llvmTypeConverter, [](AddressSpace space) -> unsigned {
71+
switch (space) {
72+
case AddressSpace::Global:
73+
return static_cast<unsigned>(
74+
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
75+
case AddressSpace::Workgroup:
76+
return static_cast<unsigned>(
77+
NVVM::NVVMMemorySpace::kSharedMemorySpace);
78+
case AddressSpace::Private:
79+
return 0;
80+
}
81+
llvm_unreachable("unknown address space enum value");
82+
return 0;
83+
});
84+
// Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
85+
// TODO: We should have a single to_nvvm_type_converter.
86+
llvmTypeConverter.addConversion(
87+
[&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
88+
populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
89+
}
90+
91+
LogicalResult
92+
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
93+
transform::TypeConverterBuilderOpInterface builder) {
94+
if (builder.getTypeConverterType() != "LLVMTypeConverter")
95+
return emitOpError("expected LLVMTypeConverter");
96+
return success();
97+
}
98+
99+
void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
100+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
101+
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
102+
populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
103+
}
104+
105+
LogicalResult
106+
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
107+
transform::TypeConverterBuilderOpInterface builder) {
108+
if (builder.getTypeConverterType() != "LLVMTypeConverter")
109+
return emitOpError("expected LLVMTypeConverter");
110+
return success();
111+
}
112+
113+
void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
114+
populatePatterns(TypeConverter &typeConverter,
115+
RewritePatternSet &patterns) {
116+
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
117+
populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
118+
}
119+
120+
LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
121+
verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
122+
if (builder.getTypeConverterType() != "LLVMTypeConverter")
123+
return emitOpError("expected LLVMTypeConverter");
124+
return success();
125+
}
126+
127+
//===----------------------------------------------------------------------===//
128+
// Apply...PatternsOp
129+
//===----------------------------------------------------------------------===//
130+
131+
void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
132+
populateGpuRewritePatterns(patterns);
133+
}
134+
50135
//===----------------------------------------------------------------------===//
51136
// ApplyUnrollVectorsSubgroupMmaOp
52137
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ struct GpuAllReduceRewriter {
387387
static constexpr int kSubgroupSize = 32;
388388
};
389389

390-
struct GpuAllReduceConversion : public RewritePattern {
391-
explicit GpuAllReduceConversion(MLIRContext *context)
390+
struct GpuAllReduceRewrite : public RewritePattern {
391+
explicit GpuAllReduceRewrite(MLIRContext *context)
392392
: RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
393393

394394
LogicalResult matchAndRewrite(Operation *op,
@@ -417,5 +417,5 @@ struct GpuAllReduceConversion : public RewritePattern {
417417
} // namespace
418418

419419
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
420-
patterns.add<GpuAllReduceConversion>(patterns.getContext());
420+
patterns.add<GpuAllReduceRewrite>(patterns.getContext());
421421
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32 use-opaque-pointers=1' -split-input-file | FileCheck %s
2+
3+
// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
4+
5+
gpu.module @test_module_0 {
6+
// CHECK-LABEL: func @gpu_index_ops()
7+
func.func @gpu_index_ops()
8+
-> (index, index, index, index, index, index,
9+
index, index, index, index, index, index,
10+
index) {
11+
%tIdX = gpu.thread_id x
12+
%tIdY = gpu.thread_id y
13+
%tIdZ = gpu.thread_id z
14+
15+
%bDimX = gpu.block_dim x
16+
%bDimY = gpu.block_dim y
17+
%bDimZ = gpu.block_dim z
18+
19+
%bIdX = gpu.block_id x
20+
%bIdY = gpu.block_id y
21+
%bIdZ = gpu.block_id z
22+
23+
%gDimX = gpu.grid_dim x
24+
%gDimY = gpu.grid_dim y
25+
%gDimZ = gpu.grid_dim z
26+
27+
// CHECK-NOT: = llvm.sext %{{.*}} : i32 to i64
28+
%laneId = gpu.lane_id
29+
30+
func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
31+
%bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
32+
%laneId
33+
: index, index, index, index, index, index,
34+
index, index, index, index, index, index,
35+
index
36+
}
37+
}
38+
39+
40+
41+
gpu.module @test_module_1 {
42+
// CHECK-LABEL: func @gpu_index_comp
43+
func.func @gpu_index_comp(%idx : index) -> index {
44+
// CHECK: = llvm.add %{{.*}}, %{{.*}} : i32
45+
%0 = arith.addi %idx, %idx : index
46+
// CHECK: llvm.return %{{.*}} : i32
47+
func.return %0 : index
48+
}
49+
}
50+
51+
transform.sequence failures(propagate) {
52+
^bb1(%toplevel_module: !transform.any_op):
53+
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
54+
: (!transform.any_op) -> !transform.any_op
55+
transform.apply_conversion_patterns to %gpu_module {
56+
transform.apply_conversion_patterns.dialect_to_llvm "arith"
57+
transform.apply_conversion_patterns.dialect_to_llvm "cf"
58+
transform.apply_conversion_patterns.vector.vector_to_llvm
59+
transform.apply_conversion_patterns.func.func_to_llvm
60+
transform.apply_conversion_patterns.dialect_to_llvm "memref"
61+
transform.apply_conversion_patterns.gpu.gpu_to_nvvm
62+
transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
63+
transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm {has_redux = true}
64+
transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm
65+
} with type_converter {
66+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
67+
{index_bitwidth = 32, use_opaque_pointers = true}
68+
} {
69+
legal_dialects = ["llvm", "memref", "nvvm"],
70+
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
71+
illegal_dialects = ["gpu"],
72+
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
73+
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
74+
"llvm.sin", "llvm.sqrt"],
75+
partial_conversion
76+
} : !transform.any_op
77+
}

0 commit comments

Comments
 (0)