Skip to content

Commit a1666d6

Browse files
committed
Insert integer attributes instead of the gpu address space
1 parent 71006c6 commit a1666d6

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -312,16 +312,15 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
312312

313313
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
314314
public:
315-
MemorySpaceToOpenCLMemorySpaceConverter() {
315+
MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
316316
addConversion([](Type t) { return t; });
317-
addConversion([](BaseMemRefType memRefType) -> std::optional<Type> {
317+
addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
318318
// Attach global addr space attribute to memrefs with no addr space attr
319319
Attribute memSpaceAttr = memRefType.getMemorySpace();
320320
if (memSpaceAttr)
321321
return std::nullopt;
322322

323-
auto addrSpaceAttr = gpu::AddressSpaceAttr::get(
324-
memRefType.getContext(), gpu::AddressSpace::Global);
323+
Attribute addrSpaceAttr = IntegerAttr::get(IntegerType::get(ctx, 64), 1);
325324
if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
326325
return MemRefType::get(memRefType.getShape(),
327326
memRefType.getElementType(),
@@ -361,7 +360,7 @@ struct GPUToLLVMSPVConversionPass final
361360

362361
// Force OpenCL address spaces when they are not present
363362
{
364-
MemorySpaceToOpenCLMemorySpaceConverter converter;
363+
MemorySpaceToOpenCLMemorySpaceConverter converter(context);
365364
AttrTypeReplacer replacer;
366365
replacer.addReplacement([&converter](BaseMemRefType origType)
367366
-> std::optional<BaseMemRefType> {
@@ -379,8 +378,6 @@ struct GPUToLLVMSPVConversionPass final
379378
gpu::ReturnOp, gpu::ShuffleOp, gpu::ThreadIdOp>();
380379

381380
populateGpuToLLVMSPVConversionPatterns(converter, patterns);
382-
populateFuncToLLVMConversionPatterns(converter, patterns);
383-
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
384381
populateGpuMemorySpaceAttributeConversions(converter);
385382

386383
if (failed(applyPartialConversion(getOperation(), target,

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -531,18 +531,20 @@ gpu.module @kernels {
531531
// CHECK-LABEL: llvm.func spir_kernelcc @no_address_spaces_complex(
532532
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
533533
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
534-
// CHECK: llvm.call @no_address_spaces_callee
534+
// CHECK: func.call @no_address_spaces_callee(%{{[0-9]+}}, %{{[0-9]+}})
535+
// CHECK-SAME: : (memref<2x2xf32, 1>, memref<4xf32, 1>)
535536
gpu.func @no_address_spaces_complex(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) kernel {
536537
func.call @no_address_spaces_callee(%arg0, %arg1) : (memref<2x2xf32>, memref<4xf32>) -> ()
537538
gpu.return
538539
}
539-
// CHECK-LABEL: llvm.func @no_address_spaces_callee(
540-
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
541-
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
540+
// CHECK-LABEL: func.func @no_address_spaces_callee(
541+
// CHECK-SAME: [[ARG0:%.*]]: memref<2x2xf32, 1>
542+
// CHECK-SAME: [[ARG1:%.*]]: memref<4xf32, 1>
542543
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
543-
// CHECK: llvm.call spir_funccc @_Z12get_group_idj([[C0]]) {
544-
// CHECK: [[LD:%.*]] = llvm.load
545-
// CHECK: llvm.store [[LD]]
544+
// CHECK: [[I0:%.*]] = llvm.call spir_funccc @_Z12get_group_idj([[C0]]) {
545+
// CHECK: [[I1:%.*]] = builtin.unrealized_conversion_cast [[I0]] : i64 to index
546+
// CHECK: [[LD:%.*]] = memref.load [[ARG0]]{{\[}}[[I1]], [[I1]]{{\]}} : memref<2x2xf32, 1>
547+
// CHECK: memref.store [[LD]], [[ARG1]]{{\[}}[[I1]]{{\]}} : memref<4xf32, 1>
546548
func.func @no_address_spaces_callee(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) {
547549
%block_id = gpu.block_id x
548550
%0 = memref.load %arg0[%block_id, %block_id] : memref<2x2xf32>

0 commit comments

Comments
 (0)