@@ -316,6 +316,53 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
316
316
}
317
317
};
318
318
319
+ // ===----------------------------------------------------------------------===//
320
+ // Subgroup query ops.
321
+ // ===----------------------------------------------------------------------===//
322
+
323
+ template <typename SubgroupOp>
324
+ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
325
+ using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
326
+ using ConvertToLLVMPattern::getTypeConverter;
327
+
328
+ LogicalResult
329
+ matchAndRewrite (SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
330
+ ConversionPatternRewriter &rewriter) const final {
331
+ constexpr StringRef funcName = [] {
332
+ if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
333
+ return " _Z16get_sub_group_id" ;
334
+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
335
+ return " _Z22get_sub_group_local_id" ;
336
+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
337
+ return " _Z18get_num_sub_groups" ;
338
+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
339
+ return " _Z18get_sub_group_size" ;
340
+ }
341
+ }();
342
+
343
+ Operation *moduleOp =
344
+ op->template getParentWithTrait <OpTrait::SymbolTable>();
345
+ Type resultTy = rewriter.getI32Type ();
346
+ LLVM::LLVMFuncOp func =
347
+ lookupOrCreateSPIRVFn (moduleOp, funcName, {}, resultTy,
348
+ /* isMemNone=*/ false , /* isConvergent=*/ false );
349
+
350
+ Location loc = op->getLoc ();
351
+ Value result = createSPIRVBuiltinCall (loc, rewriter, func, {}).getResult ();
352
+
353
+ Type indexTy = getTypeConverter ()->getIndexType ();
354
+ if (resultTy != indexTy) {
355
+ if (indexTy.getIntOrFloatBitWidth () < resultTy.getIntOrFloatBitWidth ()) {
356
+ return failure ();
357
+ }
358
+ result = rewriter.create <LLVM::ZExtOp>(loc, indexTy, result);
359
+ }
360
+
361
+ rewriter.replaceOp (op, result);
362
+ return success ();
363
+ }
364
+ };
365
+
319
366
// ===----------------------------------------------------------------------===//
320
367
// GPU To LLVM-SPV Pass.
321
368
// ===----------------------------------------------------------------------===//
@@ -337,7 +384,9 @@ struct GPUToLLVMSPVConversionPass final
337
384
338
385
target.addIllegalOp <gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
339
386
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
340
- gpu::ReturnOp, gpu::ShuffleOp, gpu::ThreadIdOp>();
387
+ gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
388
+ gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
389
+ gpu::ThreadIdOp>();
341
390
342
391
populateGpuToLLVMSPVConversionPatterns (converter, patterns);
343
392
populateGpuMemorySpaceAttributeConversions (converter);
@@ -366,11 +415,15 @@ gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
366
415
void populateGpuToLLVMSPVConversionPatterns (LLVMTypeConverter &typeConverter,
367
416
RewritePatternSet &patterns) {
368
417
patterns.add <GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
418
+ GPUSubgroupOpConversion<gpu::LaneIdOp>,
419
+ GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
420
+ GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
421
+ GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
422
+ LaunchConfigOpConversion<gpu::BlockDimOp>,
369
423
LaunchConfigOpConversion<gpu::BlockIdOp>,
424
+ LaunchConfigOpConversion<gpu::GlobalIdOp>,
370
425
LaunchConfigOpConversion<gpu::GridDimOp>,
371
- LaunchConfigOpConversion<gpu::BlockDimOp>,
372
- LaunchConfigOpConversion<gpu::ThreadIdOp>,
373
- LaunchConfigOpConversion<gpu::GlobalIdOp>>(typeConverter);
426
+ LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
374
427
MLIRContext *context = &typeConverter.getContext ();
375
428
unsigned privateAddressSpace =
376
429
gpuAddressSpaceToOCLAddressSpace (gpu::AddressSpace::Private);
0 commit comments