Skip to content

Commit ed5f0c9

Browse files
kurapov-petervictor-edskuhar
authored andcommitted
[MLIR][GPU-LLVM] Add in-pass signature update for opencl kernels (llvm#105664)
Default to Global address space for memrefs that do not have an explicit address space set in the IR. --------- Co-authored-by: Victor Perez <[email protected]> Co-authored-by: Jakub Kuderski <[email protected]> Co-authored-by: Victor Perez <[email protected]>
1 parent db72839 commit ed5f0c9

File tree

2 files changed

+93
-6
lines changed

2 files changed

+93
-6
lines changed

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#include "llvm/ADT/TypeSwitch.h"
3535
#include "llvm/Support/FormatVariadic.h"
3636

37+
#define DEBUG_TYPE "gpu-to-llvm-spv"
38+
3739
using namespace mlir;
3840

3941
namespace mlir {
@@ -316,6 +318,38 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
316318
}
317319
};
318320

321+
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
322+
public:
323+
MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
324+
addConversion([](Type t) { return t; });
325+
addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
326+
// Attach global addr space attribute to memrefs with no addr space attr
327+
Attribute memSpaceAttr = memRefType.getMemorySpace();
328+
if (memSpaceAttr)
329+
return std::nullopt;
330+
331+
unsigned globalAddrspace = storageClassToAddressSpace(
332+
spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
333+
Attribute addrSpaceAttr =
334+
IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
335+
if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
336+
return MemRefType::get(memRefType.getShape(),
337+
memRefType.getElementType(),
338+
rankedType.getLayout(), addrSpaceAttr);
339+
}
340+
return UnrankedMemRefType::get(memRefType.getElementType(),
341+
addrSpaceAttr);
342+
});
343+
addConversion([this](FunctionType type) {
344+
auto inputs = llvm::map_to_vector(
345+
type.getInputs(), [this](Type ty) { return convertType(ty); });
346+
auto results = llvm::map_to_vector(
347+
type.getResults(), [this](Type ty) { return convertType(ty); });
348+
return FunctionType::get(type.getContext(), inputs, results);
349+
});
350+
}
351+
};
352+
319353
//===----------------------------------------------------------------------===//
320354
// Subgroup query ops.
321355
//===----------------------------------------------------------------------===//
@@ -382,6 +416,21 @@ struct GPUToLLVMSPVConversionPass final
382416
LLVMTypeConverter converter(context, options);
383417
LLVMConversionTarget target(*context);
384418

419+
// Force OpenCL address spaces when they are not present
420+
{
421+
MemorySpaceToOpenCLMemorySpaceConverter converter(context);
422+
AttrTypeReplacer replacer;
423+
replacer.addReplacement([&converter](BaseMemRefType origType)
424+
-> std::optional<BaseMemRefType> {
425+
return converter.convertType<BaseMemRefType>(origType);
426+
});
427+
428+
replacer.recursivelyReplaceElementsIn(getOperation(),
429+
/*replaceAttrs=*/true,
430+
/*replaceLocs=*/false,
431+
/*replaceTypes=*/true);
432+
}
433+
385434
target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
386435
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
387436
gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -444,20 +444,20 @@ gpu.module @kernels {
444444
gpu.return
445445
}
446446

447-
// CHECK-64: llvm.func spir_kernelcc @kernel_with_conv_args(%{{.*}}: i64, %{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i64) attributes {gpu.kernel} {
448-
// CHECK-32: llvm.func spir_kernelcc @kernel_with_conv_args(%{{.*}}: i32, %{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i32) attributes {gpu.kernel} {
447+
// CHECK-64: llvm.func spir_kernelcc @kernel_with_conv_args(%{{.*}}: i64, %{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i64) attributes {gpu.kernel} {
448+
// CHECK-32: llvm.func spir_kernelcc @kernel_with_conv_args(%{{.*}}: i32, %{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i32) attributes {gpu.kernel} {
449449
gpu.func @kernel_with_conv_args(%arg0: index, %arg1: memref<index>) kernel {
450450
gpu.return
451451
}
452452

453-
// CHECK-64: llvm.func spir_kernelcc @kernel_with_sized_memref(%{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64) attributes {gpu.kernel} {
454-
// CHECK-32: llvm.func spir_kernelcc @kernel_with_sized_memref(%{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32) attributes {gpu.kernel} {
453+
// CHECK-64: llvm.func spir_kernelcc @kernel_with_sized_memref(%{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64) attributes {gpu.kernel} {
454+
// CHECK-32: llvm.func spir_kernelcc @kernel_with_sized_memref(%{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32) attributes {gpu.kernel} {
455455
gpu.func @kernel_with_sized_memref(%arg0: memref<1xindex>) kernel {
456456
gpu.return
457457
}
458458

459-
// CHECK-64: llvm.func spir_kernelcc @kernel_with_ND_memref(%{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64) attributes {gpu.kernel} {
460-
// CHECK-32: llvm.func spir_kernelcc @kernel_with_ND_memref(%{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32) attributes {gpu.kernel} {
459+
// CHECK-64: llvm.func spir_kernelcc @kernel_with_ND_memref(%{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64) attributes {gpu.kernel} {
460+
// CHECK-32: llvm.func spir_kernelcc @kernel_with_ND_memref(%{{.*}}: !llvm.ptr<1>, %{{.*}}: !llvm.ptr<1>, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32) attributes {gpu.kernel} {
461461
gpu.func @kernel_with_ND_memref(%arg0: memref<128x128x128xindex>) kernel {
462462
gpu.return
463463
}
@@ -566,6 +566,44 @@ gpu.module @kernels {
566566

567567
// -----
568568

569+
gpu.module @kernels {
570+
// CHECK: llvm.func spir_funccc @_Z12get_group_idj(i32)
571+
// CHECK-LABEL: llvm.func spir_funccc @no_address_spaces(
572+
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
573+
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
574+
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
575+
gpu.func @no_address_spaces(%arg0: memref<f32>, %arg1: memref<f32, #gpu.address_space<global>>, %arg2: memref<f32>) {
576+
gpu.return
577+
}
578+
579+
// CHECK-LABEL: llvm.func spir_kernelcc @no_address_spaces_complex(
580+
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
581+
// CHECK-SAME: %{{[a-zA-Z_][a-zA-Z0-9_]*}}: !llvm.ptr<1>
582+
// CHECK: func.call @no_address_spaces_callee(%{{[0-9]+}}, %{{[0-9]+}})
583+
// CHECK-SAME: : (memref<2x2xf32, 1>, memref<4xf32, 1>)
584+
gpu.func @no_address_spaces_complex(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) kernel {
585+
func.call @no_address_spaces_callee(%arg0, %arg1) : (memref<2x2xf32>, memref<4xf32>) -> ()
586+
gpu.return
587+
}
588+
// CHECK-LABEL: func.func @no_address_spaces_callee(
589+
// CHECK-SAME: [[ARG0:%.*]]: memref<2x2xf32, 1>
590+
// CHECK-SAME: [[ARG1:%.*]]: memref<4xf32, 1>
591+
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
592+
// CHECK: [[I0:%.*]] = llvm.call spir_funccc @_Z12get_group_idj([[C0]]) {
593+
// CHECK-32: [[I1:%.*]] = builtin.unrealized_conversion_cast [[I0]] : i32 to index
594+
// CHECK-64: [[I1:%.*]] = builtin.unrealized_conversion_cast [[I0]] : i64 to index
595+
// CHECK: [[LD:%.*]] = memref.load [[ARG0]]{{\[}}[[I1]], [[I1]]{{\]}} : memref<2x2xf32, 1>
596+
// CHECK: memref.store [[LD]], [[ARG1]]{{\[}}[[I1]]{{\]}} : memref<4xf32, 1>
597+
func.func @no_address_spaces_callee(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) {
598+
%block_id = gpu.block_id x
599+
%0 = memref.load %arg0[%block_id, %block_id] : memref<2x2xf32>
600+
memref.store %0, %arg1[%block_id] : memref<4xf32>
601+
func.return
602+
}
603+
}
604+
605+
// -----
606+
569607
// Lowering of subgroup query operations
570608

571609
// CHECK-DAG: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}

0 commit comments

Comments
 (0)