Skip to content

[mlir][spirv] Handle all zero-element memref types #73351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 24, 2023
Merged

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Nov 24, 2023

Bail out of type conversion instead of crashing.

Fixes: #73289

Bail out of type conversion instead of crashing.
@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

Changes

Bail out of type conversion instead of crashing.

Fixes: #73289


Full diff: https://github.com/llvm/llvm-project/pull/73351.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+12)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/alloc.mlir (+2)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..2b79c8022b8e85b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -469,6 +469,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
+  if (type.getNumElements() == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: zero-element memrefs are not supported\n");
+    return nullptr;
+  }
+
   int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
   int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
@@ -500,6 +506,12 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
+  if (type.getNumElements() == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: zero-element memrefs are not supported\n");
+    return nullptr;
+  }
+
   int64_t memrefSize =
       llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
   int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 7037051573bd610..2a5f81544f20a86 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -187,6 +187,8 @@ module attributes {
 {
   func.func @zero_size() {
     %0 = memref.alloc() : memref<0xf32, #spirv.storage_class<Workgroup>>
+    %1 = memref.alloc() : memref<0xi1, #spirv.storage_class<Workgroup>>
+    %2 = memref.alloc() : memref<0xi4, #spirv.storage_class<Workgroup>>
     return
   }
 }

Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure that 0-element memref is allowed actually. But fine to guard against it.

@kuhar kuhar merged commit 6ba6039 into llvm:main Nov 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] --convert-scf-to-spirv crashed with assertion failure "ArrayType needs at least one element"
3 participants