Skip to content

Commit 208f42f

Browse files
[mlir][NVVM] Add constant memory space identifier (#111141)
Also use these enums in `BasicPtxBuilderInferface.cpp`.
1 parent 8e33ff7 commit 208f42f

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ enum NVVMMemorySpace {
3535
/// Global memory space identifier.
3636
kGlobalMemorySpace = 1,
3737
/// Shared memory space identifier.
38-
kSharedMemorySpace = 3
38+
kSharedMemorySpace = 3,
39+
/// Constant memory space identifier.
40+
kConstantMemorySpace = 4
3941
};
4042

4143
/// Return the element type and number of elements associated with a wmma matrix

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
15+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1516

1617
#define DEBUG_TYPE "ptx-builder"
1718
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -26,8 +27,6 @@
2627
using namespace mlir;
2728
using namespace NVVM;
2829

29-
static constexpr int64_t kSharedMemorySpace = 3;
30-
3130
static char getRegisterType(Type type) {
3231
if (type.isInteger(1))
3332
return 'b';
@@ -43,7 +42,7 @@ static char getRegisterType(Type type) {
4342
return 'd';
4443
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
4544
// Shared address spaces is addressed with 32-bit pointers.
46-
if (ptr.getAddressSpace() == kSharedMemorySpace) {
45+
if (ptr.getAddressSpace() == NVVMMemorySpace::kSharedMemorySpace) {
4746
return 'r';
4847
}
4948
return 'l';

0 commit comments

Comments
 (0)