Skip to content

Commit a862b6d

Browse files
authored
[flang][cuda] Lower shared global to the correct NVVM address space (#131368)
Globals with the CUDA shared data attribute need to be lowered to LLVM globals in the correct address space (3). The address space is taken from the `mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace` enum declared in `mlir/Dialect/LLVMIR/NVVMDialect.h`.
1 parent fbf0276 commit a862b6d

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
4848
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
4949
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
50+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
5051
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
5152
#include "mlir/Dialect/OpenACC/OpenACC.h"
5253
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -3145,6 +3146,11 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
31453146
}
31463147
}
31473148
}
3149+
3150+
if (global.getDataAttr() &&
3151+
*global.getDataAttr() == cuf::DataAttribute::Shared)
3152+
g.setAddrSpace(mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);
3153+
31483154
rewriter.eraseOp(global);
31493155
return mlir::success();
31503156
}

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,9 @@ func.func @_QMm1Psub1(%arg0: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.c
198198

199199
// CHECK-LABEL: llvm.func @_QMm1Psub1
200200
// CHECK-COUNT-2: _FortranACUFAllocDescriptor
201+
202+
// -----
203+
204+
fir.global common @_QPshared_static__shared_mem(dense<0> : vector<28xi8>) {alignment = 8 : i64, data_attr = #cuf.cuda<shared>} : !fir.array<28xi8>
205+
206+
// CHECK: llvm.mlir.global common @_QPshared_static__shared_mem(dense<0> : vector<28xi8>) {addr_space = 3 : i32, alignment = 8 : i64} : !llvm.array<28 x i8>

0 commit comments

Comments
 (0)