Skip to content

Commit 8900c09

Browse files
authored
[mlir][nvgpu] Fix crash when handling 0D memref in OptimizeSharedMemoryPass (#124517)
This PR adds a check for 0D memref types to prevent a crash. Fixes #119855.
1 parent 740e6ae commit 8900c09

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
152152
if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
153153
return failure();
154154

155+
// Not support 0D MemRefs.
156+
if (memRefType.getRank() == 0)
157+
return failure();
158+
155159
// Abort if the given value has any sub-views; we do not do any alias
156160
// analysis.
157161
bool hasSubView = false;

mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,13 @@ func.func @abort_if_subview(%arg0: memref<128x128xf16>,
238238

239239
return %mat: vector<1x2xf16>
240240
}
241+
242+
// -----
243+
244+
// Ensure this case not crash
245+
246+
// CHECK-LABEL: func @test_0_d
247+
func.func @test_0_d() -> memref<i32, #gpu.address_space<workgroup>> {
248+
%alloc = memref.alloc() : memref<i32, #gpu.address_space<workgroup>>
249+
return %alloc : memref<i32, #gpu.address_space<workgroup>>
250+
}

0 commit comments

Comments
 (0)