Skip to content

Commit af7aa22

Browse files
authored
[MLIR][GPU] Lower subgroup query ops in gpu-to-llvm-spv (#108839)
These ops are: * gpu.subgroup_id * gpu.lane_id * gpu.num_subgroups * gpu.subgroup_size --------- Signed-off-by: Finlay Marno <[email protected]>
1 parent 39babbf commit af7aa22

File tree

2 files changed

+90
-4
lines changed

2 files changed

+90
-4
lines changed

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,53 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
316316
}
317317
};
318318

319+
//===----------------------------------------------------------------------===//
320+
// Subgroup query ops.
321+
//===----------------------------------------------------------------------===//
322+
323+
template <typename SubgroupOp>
324+
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
325+
using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
326+
using ConvertToLLVMPattern::getTypeConverter;
327+
328+
LogicalResult
329+
matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
330+
ConversionPatternRewriter &rewriter) const final {
331+
constexpr StringRef funcName = [] {
332+
if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
333+
return "_Z16get_sub_group_id";
334+
} else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
335+
return "_Z22get_sub_group_local_id";
336+
} else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
337+
return "_Z18get_num_sub_groups";
338+
} else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
339+
return "_Z18get_sub_group_size";
340+
}
341+
}();
342+
343+
Operation *moduleOp =
344+
op->template getParentWithTrait<OpTrait::SymbolTable>();
345+
Type resultTy = rewriter.getI32Type();
346+
LLVM::LLVMFuncOp func =
347+
lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
348+
/*isMemNone=*/false, /*isConvergent=*/false);
349+
350+
Location loc = op->getLoc();
351+
Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();
352+
353+
Type indexTy = getTypeConverter()->getIndexType();
354+
if (resultTy != indexTy) {
355+
if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
356+
return failure();
357+
}
358+
result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
359+
}
360+
361+
rewriter.replaceOp(op, result);
362+
return success();
363+
}
364+
};
365+
319366
//===----------------------------------------------------------------------===//
320367
// GPU To LLVM-SPV Pass.
321368
//===----------------------------------------------------------------------===//
@@ -337,7 +384,9 @@ struct GPUToLLVMSPVConversionPass final
337384

338385
target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
339386
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
340-
gpu::ReturnOp, gpu::ShuffleOp, gpu::ThreadIdOp>();
387+
gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
388+
gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
389+
gpu::ThreadIdOp>();
341390

342391
populateGpuToLLVMSPVConversionPatterns(converter, patterns);
343392
populateGpuMemorySpaceAttributeConversions(converter);
@@ -366,11 +415,15 @@ gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
366415
void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &typeConverter,
367416
RewritePatternSet &patterns) {
368417
patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
418+
GPUSubgroupOpConversion<gpu::LaneIdOp>,
419+
GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
420+
GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
421+
GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
422+
LaunchConfigOpConversion<gpu::BlockDimOp>,
369423
LaunchConfigOpConversion<gpu::BlockIdOp>,
424+
LaunchConfigOpConversion<gpu::GlobalIdOp>,
370425
LaunchConfigOpConversion<gpu::GridDimOp>,
371-
LaunchConfigOpConversion<gpu::BlockDimOp>,
372-
LaunchConfigOpConversion<gpu::ThreadIdOp>,
373-
LaunchConfigOpConversion<gpu::GlobalIdOp>>(typeConverter);
426+
LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
374427
MLIRContext *context = &typeConverter.getContext();
375428
unsigned privateAddressSpace =
376429
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,36 @@ gpu.module @kernels {
563563
gpu.return
564564
}
565565
}
566+
567+
// -----
568+
569+
// Lowering of subgroup query operations
570+
571+
// CHECK-DAG: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
572+
// CHECK-DAG: llvm.func spir_funccc @_Z18get_num_sub_groups() -> i32 attributes {no_unwind, will_return}
573+
// CHECK-DAG: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
574+
// CHECK-DAG: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
575+
576+
577+
gpu.module @subgroup_operations {
578+
// CHECK-LABEL: @gpu_subgroup
579+
func.func @gpu_subgroup() {
580+
// CHECK: %[[SG_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
581+
// CHECK-32-NOT: llvm.zext
582+
// CHECK-64 %{{.*}} = llvm.zext %[[SG_ID]] : i32 to i64
583+
%0 = gpu.subgroup_id : index
584+
// CHECK: %[[SG_LOCAL_ID:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
585+
// CHECK-32-NOT: llvm.zext
586+
// CHECK-64: %{{.*}} = llvm.zext %[[SG_LOCAL_ID]] : i32 to i64
587+
%1 = gpu.lane_id
588+
// CHECK: %[[NUM_SGS:.*]] = llvm.call spir_funccc @_Z18get_num_sub_groups() {no_unwind, will_return} : () -> i32
589+
// CHECK-32-NOT: llvm.zext
590+
// CHECK-64: %{{.*}} = llvm.zext %[[NUM_SGS]] : i32 to i64
591+
%2 = gpu.num_subgroups : index
592+
// CHECK: %[[SG_SIZE:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size() {no_unwind, will_return} : () -> i32
593+
// CHECK-32-NOT: llvm.zext
594+
// CHECK-64: %{{.*}} = llvm.zext %[[SG_SIZE]] : i32 to i64
595+
%3 = gpu.subgroup_size : index
596+
return
597+
}
598+
}

0 commit comments

Comments
 (0)