|
18 | 18 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
19 | 19 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
20 | 20 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
| 21 | +#include "mlir/IR/BuiltinAttributes.h" |
21 | 22 | #include "mlir/IR/BuiltinTypes.h"
|
22 | 23 | #include "mlir/IR/FunctionInterfaces.h"
|
23 | 24 | #include "mlir/Transforms/DialectConversion.h"
|
@@ -56,7 +57,18 @@ using namespace mlir;
|
56 | 57 | MAP_FN(spirv::StorageClass::Output, 10)
|
57 | 58 |
|
58 | 59 | Optional<spirv::StorageClass>
|
59 |
| -spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) { |
| 60 | +spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) { |
| 61 | + // Handle null memory space attribute specially. |
| 62 | + if (!memorySpaceAttr) |
| 63 | + return spirv::StorageClass::StorageBuffer; |
| 64 | + |
| 65 | + // Unknown dialect custom attributes are not supported by default. |
| 66 | + // Downstream callers should plug in more specialized ones. |
| 67 | + auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>(); |
| 68 | + if (!intAttr) |
| 69 | + return llvm::None; |
| 70 | + unsigned memorySpace = intAttr.getInt(); |
| 71 | + |
60 | 72 | #define STORAGE_SPACE_MAP_FN(storage, space) \
|
61 | 73 | case space: \
|
62 | 74 | return storage;
|
@@ -99,7 +111,18 @@ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
|
99 | 111 | MAP_FN(spirv::StorageClass::Image, 7)
|
100 | 112 |
|
101 | 113 | Optional<spirv::StorageClass>
|
102 |
| -spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) { |
| 114 | +spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) { |
| 115 | + // Handle null memory space attribute specially. |
| 116 | + if (!memorySpaceAttr) |
| 117 | + return spirv::StorageClass::CrossWorkgroup; |
| 118 | + |
| 119 | + // Unknown dialect custom attributes are not supported by default. |
| 120 | + // Downstream callers should plug in more specialized ones. |
| 121 | + auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>(); |
| 122 | + if (!intAttr) |
| 123 | + return llvm::None; |
| 124 | + unsigned memorySpace = intAttr.getInt(); |
| 125 | + |
103 | 126 | #define STORAGE_SPACE_MAP_FN(storage, space) \
|
104 | 127 | case space: \
|
105 | 128 | return storage;
|
@@ -143,17 +166,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
|
143 | 166 | addConversion([](Type type) { return type; });
|
144 | 167 |
|
145 | 168 | addConversion([this](BaseMemRefType memRefType) -> Optional<Type> {
|
146 |
| - // Expect IntegerAttr memory spaces. The attribute can be missing for the |
147 |
| - // case of memory space == 0. |
148 |
| - Attribute spaceAttr = memRefType.getMemorySpace(); |
149 |
| - if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) { |
150 |
| - LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType |
151 |
| - << " due to non-IntegerAttr memory space\n"); |
152 |
| - return llvm::None; |
153 |
| - } |
154 |
| - |
155 |
| - unsigned space = memRefType.getMemorySpaceAsInt(); |
156 |
| - auto storage = this->memorySpaceMap(space); |
| 169 | + Optional<spirv::StorageClass> storage = |
| 170 | + this->memorySpaceMap(memRefType.getMemorySpace()); |
157 | 171 | if (!storage) {
|
158 | 172 | LLVM_DEBUG(llvm::dbgs()
|
159 | 173 | << "cannot convert " << memRefType
|
|
0 commit comments