Skip to content

Commit 47e953e

Browse files
committed
[mlir][spirv] Support attribute in MapMemRefStorageClassPass
MemRef memory space actually can be an attribute. Update the map function signature to accept an attribute. The default mappings can still only covers numeric ones, but this allows downstream callers to extend with custom memory spaces. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D138257
1 parent fe5b9a6 commit 47e953e

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ class SPIRVTypeConverter;
2323
namespace spirv {
2424
/// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
2525
using MemorySpaceToStorageClassMap =
26-
std::function<Optional<spirv::StorageClass>(unsigned)>;
26+
std::function<Optional<spirv::StorageClass>(Attribute)>;
2727

2828
/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
2929
/// using the default rule. Returns None if the memory space is unknown.
30-
Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
30+
Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(Attribute);
3131
/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
3232
/// using the default rule. Returns None if the storage class is unsupported.
3333
Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
3434

3535
/// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V
3636
/// using the default rule. Returns None if the memory space is unknown.
37-
Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(unsigned);
37+
Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(Attribute);
3838
/// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces
3939
/// using the default rule. Returns None if the storage class is unsupported.
4040
Optional<unsigned> mapOpenCLStorageClassToMemorySpace(spirv::StorageClass);

mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1919
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
2020
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21+
#include "mlir/IR/BuiltinAttributes.h"
2122
#include "mlir/IR/BuiltinTypes.h"
2223
#include "mlir/IR/FunctionInterfaces.h"
2324
#include "mlir/Transforms/DialectConversion.h"
@@ -56,7 +57,18 @@ using namespace mlir;
5657
MAP_FN(spirv::StorageClass::Output, 10)
5758

5859
Optional<spirv::StorageClass>
59-
spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
60+
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
61+
// Handle null memory space attribute specially.
62+
if (!memorySpaceAttr)
63+
return spirv::StorageClass::StorageBuffer;
64+
65+
// Unknown dialect custom attributes are not supported by default.
66+
// Downstream callers should plug in more specialized ones.
67+
auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
68+
if (!intAttr)
69+
return llvm::None;
70+
unsigned memorySpace = intAttr.getInt();
71+
6072
#define STORAGE_SPACE_MAP_FN(storage, space) \
6173
case space: \
6274
return storage;
@@ -99,7 +111,18 @@ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
99111
MAP_FN(spirv::StorageClass::Image, 7)
100112

101113
Optional<spirv::StorageClass>
102-
spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) {
114+
spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
115+
// Handle null memory space attribute specially.
116+
if (!memorySpaceAttr)
117+
return spirv::StorageClass::CrossWorkgroup;
118+
119+
// Unknown dialect custom attributes are not supported by default.
120+
// Downstream callers should plug in more specialized ones.
121+
auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
122+
if (!intAttr)
123+
return llvm::None;
124+
unsigned memorySpace = intAttr.getInt();
125+
103126
#define STORAGE_SPACE_MAP_FN(storage, space) \
104127
case space: \
105128
return storage;
@@ -143,17 +166,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
143166
addConversion([](Type type) { return type; });
144167

145168
addConversion([this](BaseMemRefType memRefType) -> Optional<Type> {
146-
// Expect IntegerAttr memory spaces. The attribute can be missing for the
147-
// case of memory space == 0.
148-
Attribute spaceAttr = memRefType.getMemorySpace();
149-
if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) {
150-
LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
151-
<< " due to non-IntegerAttr memory space\n");
152-
return llvm::None;
153-
}
154-
155-
unsigned space = memRefType.getMemorySpaceAsInt();
156-
auto storage = this->memorySpaceMap(space);
169+
Optional<spirv::StorageClass> storage =
170+
this->memorySpaceMap(memRefType.getMemorySpace());
157171
if (!storage) {
158172
LLVM_DEBUG(llvm::dbgs()
159173
<< "cannot convert " << memRefType

0 commit comments

Comments
 (0)