Skip to content

Commit b40f238

Browse files
simplifying pattern application
1 parent a5b8a27 commit b40f238

File tree

2 files changed

+7
-26
lines changed

2 files changed

+7
-26
lines changed

mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ class Pass;
2323
#define GEN_PASS_DECL_CONVERTGPUTOAMDGPUPASS
2424
#include "mlir/Conversion/Passes.h.inc"
2525

26-
void populateSubgroupReduceLoweringPatterns(LLVMTypeConverter &converter,
27-
RewritePatternSet &patterns,
26+
void populateSubgroupReduceLoweringPatterns(RewritePatternSet &patterns,
2827
unsigned subgroupSize,
2928
PatternBenefit benefit);
3029
// void populateGPUToAMDGPUConversionPatterns(LLVMTypeConverter &converter,

mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88

99
#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
1010

11-
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12-
#include "mlir/Conversion/LLVMCommon/Pattern.h"
13-
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1411
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1612
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1713
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1814
#include "mlir/IR/BuiltinTypes.h"
@@ -23,15 +19,8 @@
2319
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2420
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2521

22+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
2623
#include "llvm/Support/FormatVariadic.h"
27-
#include "llvm/Support/MathExtras.h"
28-
#include <cassert>
29-
#include <cstdint>
30-
31-
#include "../LLVMCommon/MemRefDescriptor.h"
32-
33-
#include "llvm/ADT/STLExtras.h"
34-
#include <optional>
3524

3625
namespace mlir {
3726
#define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
@@ -180,24 +169,17 @@ struct ConvertGPUToAMDGPUPass
180169

181170
void runOnOperation() override {
182171
RewritePatternSet patterns(&getContext());
183-
LLVMTypeConverter converter(&getContext());
184-
LLVMConversionTarget target(getContext());
185-
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
186-
target.addLegalDialect<::mlir::amdgpu::AMDGPUDialect>();
187-
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
188-
189172
int subgroupSizeInt = static_cast<int>(subgroupSize);
190-
populateSubgroupReduceLoweringPatterns(converter, patterns, subgroupSizeInt,
173+
populateSubgroupReduceLoweringPatterns(patterns, subgroupSizeInt,
191174
PatternBenefit(1));
192-
if (failed(applyPartialConversion(getOperation(), target,
193-
std::move(patterns))))
194-
signalPassFailure();
175+
walkAndApplyPatterns(getOperation(), std::move(patterns));
195176
}
196177
};
197178
} // namespace
198179

199-
void mlir::populateSubgroupReduceLoweringPatterns(
200-
LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit) {
180+
void mlir::populateSubgroupReduceLoweringPatterns(RewritePatternSet &patterns,
181+
unsigned subgroupSize,
182+
PatternBenefit benefit) {
201183
patterns.add<ScalarSubgroupReduceToShuffles>(
202184
patterns.getContext(), subgroupSize, /*matchClustered=*/true, benefit);
203185
}

0 commit comments

Comments
 (0)