Skip to content

Commit 7c4de2e

Browse files
committed
[mlir][StandardToSPIRV] Add support for lowering memref<?xi1> to SPIR-V
Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D100452
1 parent 34367dd commit 7c4de2e

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,6 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
375375
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
376376
const SPIRVTypeConverter::Options &options,
377377
MemRefType type) {
378-
if (!type.hasStaticShape()) {
379-
LLVM_DEBUG(llvm::dbgs()
380-
<< type << " dynamic shape on i1 is not supported yet\n");
381-
return nullptr;
382-
}
383-
384378
Optional<spirv::StorageClass> storageClass =
385379
SPIRVTypeConverter::getStorageClassForMemorySpace(
386380
type.getMemorySpaceAsInt());
@@ -411,6 +405,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
411405
return nullptr;
412406
}
413407

408+
if (!type.hasStaticShape()) {
409+
auto arrayType =
410+
spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
411+
return wrapInStructAndGetPointer(arrayType, *storageClass);
412+
}
413+
414414
int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
415415
auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
416416
auto arrayType =

mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,10 @@ module attributes {
511511
// CHECK-SAME: memref<*xi32>
512512
func @unranked_memref(%arg0: memref<*xi32>) { return }
513513

514-
// Check that dynamic dims on i1 are not supported.
515514
// CHECK-LABEL: func @memref_1bit_type
516-
// CHECK-SAME: memref<?xi1>
515+
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
516+
// NOEMU-LABEL: func @memref_1bit_type
517+
// NOEMU-SAME: memref<?xi1>
517518
func @memref_1bit_type(%arg0: memref<?xi1>) { return }
518519

519520
// CHECK-LABEL: func @dynamic_dim_memref

0 commit comments

Comments
 (0)