Skip to content

Commit 9dad32c

Browse files
committed
[mlir][nvgpu] Improve finding module Op for mbarrier.create
The current transformation expects the module op to be two levels above the mbarrier.create op; however, that is not always the case. This change instead searches upward through the enclosing parent ops until it finds the module op. Reviewed By: nicolasvasilache. Differential Revision: https://reviews.llvm.org/D155825
1 parent 70c2e06 commit 9dad32c

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,13 @@ struct NVGPUMBarrierCreateLowering
744744
matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
745745
ConversionPatternRewriter &rewriter) const override {
746746
Operation *funcOp = op->getParentOp();
747-
Operation *mOp = funcOp->getParentOp();
748747
MemRefType barrierType =
749748
createMBarrierMemrefType(rewriter, op.getBarrier().getType());
750749

751750
memref::GlobalOp global;
752-
if (auto moduleOp = dyn_cast<gpu::GPUModuleOp>(mOp))
751+
if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
753752
global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
754-
else if (auto moduleOp = dyn_cast<ModuleOp>(mOp))
753+
else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
755754
global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
756755

757756
rewriter.setInsertionPoint(op);

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,18 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
635635
func.return
636636
}
637637

638+
// -----
639+
640+
!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
641+
module @find_parent{
642+
func.func @main() {
643+
%c1 = arith.constant 1 : index
644+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
645+
threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
646+
// CHECK: memref.get_global @__mbarrier : memref<1xi64, 3>
647+
%barrier = nvgpu.mbarrier.create -> !barrierType
648+
gpu.terminator
649+
}
650+
func.return
651+
}
652+
}

0 commit comments

Comments
 (0)