@@ -52,6 +52,25 @@ namespace mlir {
52
52
53
53
using namespace mlir ;
54
54
55
+ // Truncate or extend the result depending on the index bitwidth specified
56
+ // by the LLVMTypeConverter options.
57
+ static Value truncOrExtToLLVMType (ConversionPatternRewriter &rewriter,
58
+ Location loc, Value value,
59
+ const LLVMTypeConverter &converter) {
60
+ int64_t intWidth = cast<IntegerType>(value.getType ()).getWidth ();
61
+ int64_t indexBitwidth = converter.getIndexTypeBitwidth ();
62
+ auto indexBitwidthType =
63
+ IntegerType::get (rewriter.getContext (), converter.getIndexTypeBitwidth ());
64
+ // TODO: use <=> in C++20.
65
+ if (indexBitwidth > intWidth) {
66
+ return rewriter.create <LLVM::SExtOp>(loc, indexBitwidthType, value);
67
+ }
68
+ if (indexBitwidth < intWidth) {
69
+ return rewriter.create <LLVM::TruncOp>(loc, indexBitwidthType, value);
70
+ }
71
+ return value;
72
+ }
73
+
55
74
/// Returns true if the given `gpu.func` can be safely called using the bare
56
75
/// pointer calling convention.
57
76
static bool canBeCalledWithBarePointers (gpu::GPUFuncOp func) {
@@ -113,6 +132,35 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
113
132
}
114
133
};
115
134
135
+ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
136
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
137
+
138
+ GPUSubgroupSizeOpToROCDL (const LLVMTypeConverter &converter,
139
+ amdgpu::Chipset chipset)
140
+ : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
141
+ chipset (chipset) {}
142
+
143
+ LogicalResult
144
+ matchAndRewrite (gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
145
+ ConversionPatternRewriter &rewriter) const override {
146
+ LLVM::ConstantRangeAttr bounds = nullptr ;
147
+ bool isBeforeGfx10 = chipset.majorVersion < 10 ;
148
+ if (auto upperBoundAttr = op.getUpperBoundAttr ()) {
149
+ bounds = rewriter.getAttr <LLVM::ConstantRangeAttr>(
150
+ /* bitWidth=*/ 32 , /* lower=*/ isBeforeGfx10 ? 64 : 32 ,
151
+ /* upper=*/ op.getUpperBoundAttr ().getInt () + 1 );
152
+ }
153
+ Value wavefrontOp = rewriter.create <ROCDL::WavefrontSizeOp>(
154
+ op.getLoc (), rewriter.getI32Type (), bounds);
155
+ wavefrontOp = truncOrExtToLLVMType (rewriter, op.getLoc (), wavefrontOp,
156
+ *getTypeConverter ());
157
+ rewriter.replaceOp (op, {wavefrontOp});
158
+ return success ();
159
+ }
160
+
161
+ const amdgpu::Chipset chipset;
162
+ };
163
+
116
164
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern <gpu::ShuffleOp> {
117
165
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
118
166
@@ -322,7 +370,8 @@ struct LowerGpuOpsToROCDLOpsPass final
322
370
323
371
populateAMDGPUToROCDLConversionPatterns (converter, llvmPatterns,
324
372
*maybeChipset);
325
- populateGpuToROCDLConversionPatterns (converter, llvmPatterns, runtime);
373
+ populateGpuToROCDLConversionPatterns (converter, llvmPatterns, runtime,
374
+ *maybeChipset);
326
375
configureGpuToROCDLConversionLegality (target);
327
376
if (failed (applyPartialConversion (m, target, std::move (llvmPatterns))))
328
377
signalPassFailure ();
@@ -370,7 +419,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
370
419
371
420
void mlir::populateGpuToROCDLConversionPatterns (
372
421
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
373
- mlir::gpu::amd::Runtime runtime) {
422
+ mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset ) {
374
423
using gpu::index_lowering::IndexKind;
375
424
using gpu::index_lowering::IntrType;
376
425
using mlir::gpu::amd::Runtime;
@@ -408,7 +457,10 @@ void mlir::populateGpuToROCDLConversionPatterns(
408
457
// TODO: Add alignment for workgroup memory
409
458
patterns.add <GPUDynamicSharedMemoryOpLowering>(converter);
410
459
411
- patterns.add <GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
460
+ patterns
461
+ .add <GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>(
462
+ converter);
463
+ patterns.add <GPUSubgroupSizeOpToROCDL>(converter, chipset);
412
464
413
465
populateMathToROCDLConversionPatterns (converter, patterns);
414
466
}
0 commit comments