Skip to content

Commit 4defac9

Browse files
[mlir][GPUToNVVM] Add benefit to populate functions (#128484)
Certain GPU->NVVM patterns compete with Arith->LLVM patterns. (The ones that lower to libdevice.) Add an optional `benefit` parameter to all `populate` functions so that users can give preference to GPU->NVVM patterns.
1 parent 5bddadf commit 4defac9

File tree

10 files changed

+161
-122
lines changed

10 files changed

+161
-122
lines changed

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
1212
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13+
#include "mlir/IR/PatternMatch.h"
1314
#include <memory>
1415

1516
namespace mlir {
@@ -35,18 +36,27 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
3536
/// GPU dialect to NVVM.
3637
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
3738

39+
/// Populate patterns that lower certain arith and math dialect ops to
40+
/// libdevice calls.
41+
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
42+
RewritePatternSet &patterns,
43+
PatternBenefit benefit = 1);
44+
3845
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
3946
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
40-
RewritePatternSet &patterns);
47+
RewritePatternSet &patterns,
48+
PatternBenefit benefit = 1);
4149

4250
/// Populate GpuSubgroupReduce pattern to NVVM. It generates a specific nvvm
4351
/// op that is not available on every GPU.
4452
void populateGpuSubgroupReduceOpLoweringPattern(
45-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
53+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
54+
PatternBenefit benefit = 1);
4655

4756
/// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
4857
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter,
49-
RewritePatternSet &patterns);
58+
RewritePatternSet &patterns,
59+
PatternBenefit benefit = 1);
5060
} // namespace mlir
5161

5262
#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def ApplyGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,
2626
Collects patterns that convert GPU dialect ops to NVVM dialect ops. These
2727
patterns require an "LLVMTypeConverter".
2828
}];
29+
let arguments = (ins DefaultValuedAttr<I16Attr, "1">:$benefit);
2930
let assemblyFormat = "attr-dict";
3031
}
3132

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ struct GPUDynamicSharedMemoryOpLowering
4343
using ConvertOpToLLVMPattern<
4444
gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
4545
GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
46-
unsigned alignmentBit = 0)
47-
: ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
46+
unsigned alignmentBit = 0,
47+
PatternBenefit benefit = 1)
48+
: ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter, benefit),
4849
alignmentBit(alignmentBit) {}
4950

5051
LogicalResult
@@ -81,8 +82,9 @@ struct GPUFuncOpLoweringOptions {
8182

8283
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
8384
GPUFuncOpLowering(const LLVMTypeConverter &converter,
84-
const GPUFuncOpLoweringOptions &options)
85-
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
85+
const GPUFuncOpLoweringOptions &options,
86+
PatternBenefit benefit = 1)
87+
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter, benefit),
8688
allocaAddrSpace(options.allocaAddrSpace),
8789
workgroupAddrSpace(options.workgroupAddrSpace),
8890
kernelAttributeName(options.kernelAttributeName),

mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,16 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
3636
IntrType intrType;
3737

3838
public:
39-
explicit OpLowering(const LLVMTypeConverter &typeConverter)
40-
: ConvertOpToLLVMPattern<Op>(typeConverter),
39+
explicit OpLowering(const LLVMTypeConverter &typeConverter,
40+
PatternBenefit benefit = 1)
41+
: ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
4142
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
4243
indexKind(IndexKind::Other), intrType(IntrType::None) {}
4344

4445
explicit OpLowering(const LLVMTypeConverter &typeConverter,
45-
IndexKind indexKind, IntrType intrType)
46-
: ConvertOpToLLVMPattern<Op>(typeConverter),
46+
IndexKind indexKind, IntrType intrType,
47+
PatternBenefit benefit = 1)
48+
: ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
4749
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
4850
indexKind(indexKind), intrType(intrType) {}
4951

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
5757
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
5858
StringRef f32Func, StringRef f64Func,
5959
StringRef f32ApproxFunc, StringRef f16Func,
60-
StringRef i32Func = "")
61-
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
60+
StringRef i32Func = "",
61+
PatternBenefit benefit = 1)
62+
: ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
6263
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
6364
i32Func(i32Func) {}
6465

0 commit comments

Comments
 (0)