8
8
9
9
#include " mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
10
10
11
- #include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
12
- #include " mlir/Conversion/LLVMCommon/Pattern.h"
13
- #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
14
11
#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15
- #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
16
12
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
17
13
#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
18
14
#include " mlir/IR/BuiltinTypes.h"
23
19
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
24
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
25
21
22
+ #include " mlir/Transforms/WalkPatternRewriteDriver.h"
26
23
#include " llvm/Support/FormatVariadic.h"
27
- #include " llvm/Support/MathExtras.h"
28
- #include < cassert>
29
- #include < cstdint>
30
-
31
- #include " ../LLVMCommon/MemRefDescriptor.h"
32
-
33
- #include " llvm/ADT/STLExtras.h"
34
- #include < optional>
35
24
36
25
namespace mlir {
37
26
#define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
@@ -180,24 +169,17 @@ struct ConvertGPUToAMDGPUPass
180
169
181
170
void runOnOperation () override {
182
171
RewritePatternSet patterns (&getContext ());
183
- LLVMTypeConverter converter (&getContext ());
184
- LLVMConversionTarget target (getContext ());
185
- target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
186
- target.addLegalDialect <::mlir::amdgpu::AMDGPUDialect>();
187
- target.addLegalDialect <::mlir::ROCDL::ROCDLDialect>();
188
-
189
172
int subgroupSizeInt = static_cast <int >(subgroupSize);
190
- populateSubgroupReduceLoweringPatterns (converter, patterns, subgroupSizeInt,
173
+ populateSubgroupReduceLoweringPatterns (patterns, subgroupSizeInt,
191
174
PatternBenefit (1 ));
192
- if (failed (applyPartialConversion (getOperation (), target,
193
- std::move (patterns))))
194
- signalPassFailure ();
175
+ walkAndApplyPatterns (getOperation (), std::move (patterns));
195
176
}
196
177
};
197
178
} // namespace
198
179
199
- void mlir::populateSubgroupReduceLoweringPatterns (
200
- LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit) {
180
+ void mlir::populateSubgroupReduceLoweringPatterns (RewritePatternSet &patterns,
181
+ unsigned subgroupSize,
182
+ PatternBenefit benefit) {
201
183
patterns.add <ScalarSubgroupReduceToShuffles>(
202
184
patterns.getContext (), subgroupSize, /* matchClustered=*/ true , benefit);
203
185
}
0 commit comments