Skip to content

Commit 7fc792c

Browse files
authored
[MLIR] Enable GPU Dialect to SYCL runtime integration (#71430)
GPU Dialect lowering to SYCL runtime is driven by spirv.target_env attached to gpu.module. As a result, spirv.target_env remains as an input to LLVMIR Translation. A SPIRVToLLVMIRTranslation without any actual translation is added to avoid an unregistered error in mlir-cpu-runner. SelectObjectAttr.cpp is updated to: 1) pass the binary size argument to getModuleLoadFn; 2) pass the parameter count to getKernelLaunchFn. This change does not impact CUDA and ROCM usage, since both mlir_cuda_runtime and mlir_rocm_runtime are already updated to accept and ignore the extra arguments.
1 parent 2284771 commit 7fc792c

File tree

15 files changed

+322
-17
lines changed

15 files changed

+322
-17
lines changed

mlir/include/mlir/Target/LLVMIR/Dialect/All.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
2727
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
2828
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
29+
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
2930
#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h"
3031

3132
namespace mlir {
@@ -45,6 +46,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
4546
registerOpenACCDialectTranslation(registry);
4647
registerOpenMPDialectTranslation(registry);
4748
registerROCDLDialectTranslation(registry);
49+
registerSPIRVDialectTranslation(registry);
4850
registerX86VectorDialectTranslation(registry);
4951

5052
// Extension required for translating GPU offloading Ops.
@@ -61,6 +63,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry &registry) {
6163
registerLLVMDialectTranslation(registry);
6264
registerNVVMDialectTranslation(registry);
6365
registerROCDLDialectTranslation(registry);
66+
registerSPIRVDialectTranslation(registry);
6467

6568
// Extension required for translating GPU offloading Ops.
6669
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- SPIRVToLLVMIRTranslation.h - SPIR-V to LLVM IR -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides registration calls for SPIR-V dialect to LLVM IR translation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TARGET_LLVMIR_DIALECT_SPIRV_SPIRVTOLLVMIRTRANSLATION_H
14+
#define MLIR_TARGET_LLVMIR_DIALECT_SPIRV_SPIRVTOLLVMIRTRANSLATION_H
15+
16+
namespace mlir {
17+
18+
class DialectRegistry;
19+
class MLIRContext;
20+
21+
/// Register the SPIR-V dialect and the translation from it to the LLVM IR in
22+
/// the given registry.
23+
void registerSPIRVDialectTranslation(DialectRegistry &registry);
24+
25+
/// Register the SPIR-V dialect and the translation from it in the registry
26+
/// associated with the given context.
27+
void registerSPIRVDialectTranslation(MLIRContext &context);
28+
29+
} // namespace mlir
30+
31+
#endif // MLIR_TARGET_LLVMIR_DIALECT_SPIRV_SPIRVTOLLVMIRTRANSLATION_H

mlir/lib/Target/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
5858
MLIROpenACCToLLVMIRTranslation
5959
MLIROpenMPToLLVMIRTranslation
6060
MLIRROCDLToLLVMIRTranslation
61+
MLIRSPIRVToLLVMIRTranslation
6162
)
6263

6364
add_mlir_translation_library(MLIRTargetLLVMIRImport

mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ add_subdirectory(NVVM)
99
add_subdirectory(OpenACC)
1010
add_subdirectory(OpenMP)
1111
add_subdirectory(ROCDL)
12+
add_subdirectory(SPIRV)
1213
add_subdirectory(X86Vector)

mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class LaunchKernel {
175175
IRBuilderBase &builder;
176176
mlir::LLVM::ModuleTranslation &moduleTranslation;
177177
Type *i32Ty{};
178+
Type *i64Ty{};
178179
Type *voidTy{};
179180
Type *intPtrTy{};
180181
PointerType *ptrTy{};
@@ -216,6 +217,7 @@ llvm::LaunchKernel::LaunchKernel(
216217
mlir::LLVM::ModuleTranslation &moduleTranslation)
217218
: module(module), builder(builder), moduleTranslation(moduleTranslation) {
218219
i32Ty = builder.getInt32Ty();
220+
i64Ty = builder.getInt64Ty();
219221
ptrTy = builder.getPtrTy(0);
220222
voidTy = builder.getVoidTy();
221223
intPtrTy = builder.getIntPtrTy(module.getDataLayout());
@@ -224,11 +226,11 @@ llvm::LaunchKernel::LaunchKernel(
224226
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
225227
return module.getOrInsertFunction(
226228
"mgpuLaunchKernel",
227-
FunctionType::get(
228-
voidTy,
229-
ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
230-
intPtrTy, intPtrTy, i32Ty, ptrTy, ptrTy, ptrTy}),
231-
false));
229+
FunctionType::get(voidTy,
230+
ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
231+
intPtrTy, intPtrTy, intPtrTy, i32Ty,
232+
ptrTy, ptrTy, ptrTy, i64Ty}),
233+
false));
232234
}
233235

234236
llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
@@ -251,7 +253,7 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
251253
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
252254
return module.getOrInsertFunction(
253255
"mgpuModuleLoad",
254-
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy}), false));
256+
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
255257
}
256258

257259
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
@@ -391,10 +393,24 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
391393
if (!binary)
392394
return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
393395

396+
auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
397+
if (!binaryVar)
398+
return op.emitError() << "Binary is not a global variable: "
399+
<< binaryIdentifier;
400+
llvm::Constant *binaryInit = binaryVar->getInitializer();
401+
auto binaryDataSeq =
402+
dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
403+
if (!binaryDataSeq)
404+
return op.emitError() << "Couldn't find binary data array: "
405+
<< binaryIdentifier;
406+
llvm::Constant *binarySize =
407+
llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
408+
binaryDataSeq->getElementByteSize());
409+
394410
Value *moduleObject =
395411
object.getFormat() == gpu::CompilationTarget::Assembly
396412
? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
397-
: builder.CreateCall(getModuleLoadFn(), {binary});
413+
: builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
398414

399415
// Load the kernel function.
400416
Value *moduleFunction = builder.CreateCall(
@@ -413,6 +429,9 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
413429
stream = builder.CreateCall(getStreamCreateFn(), {});
414430
}
415431

432+
llvm::Constant *paramsCount =
433+
llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
434+
416435
// Create the launch call.
417436
Value *nullPtr = ConstantPointerNull::get(ptrTy);
418437

@@ -426,10 +445,10 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
426445
ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
427446
dynamicMemorySize, stream, argArray, nullPtr}));
428447
} else {
429-
builder.CreateCall(
430-
getKernelLaunchFn(),
431-
ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
432-
dynamicMemorySize, stream, argArray, nullPtr}));
448+
builder.CreateCall(getKernelLaunchFn(),
449+
ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
450+
bz, dynamicMemorySize, stream,
451+
argArray, nullPtr, paramsCount}));
433452
}
434453

435454
// Sync & destroy the stream, for synchronous launches.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
add_mlir_translation_library(MLIRSPIRVToLLVMIRTranslation
2+
SPIRVToLLVMIRTranslation.cpp
3+
4+
LINK_COMPONENTS
5+
Core
6+
7+
LINK_LIBS PUBLIC
8+
MLIRIR
9+
MLIRLLVMDialect
10+
MLIRSPIRVDialect
11+
MLIRSupport
12+
MLIRTargetLLVMIRExport
13+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- SPIRVToLLVMIRTranslation.cpp - Translate SPIR-V to LLVM IR ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a translation between the MLIR SPIR-V dialect and
10+
// LLVM IR.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
15+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16+
#include "mlir/IR/BuiltinAttributes.h"
17+
#include "mlir/IR/Operation.h"
18+
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
19+
20+
using namespace mlir;
21+
using namespace mlir::LLVM;
22+
23+
// Registry overload: inserts the SPIR-V dialect into `registry`. Only the
// dialect itself is inserted here; this function attaches no translation
// interface. Per the commit description, the registration exists so that a
// spirv.target_env attribute left on a gpu.module does not cause an
// "unregistered" error when the module reaches LLVM IR translation.
void mlir::registerSPIRVDialectTranslation(DialectRegistry &registry) {
24+
registry.insert<spirv::SPIRVDialect>();
25+
}
26+
27+
// Context overload: builds a temporary DialectRegistry, populates it through
// the DialectRegistry overload of this function, and appends the result to the
// registry associated with `context`.
void mlir::registerSPIRVDialectTranslation(MLIRContext &context) {
28+
DialectRegistry registry;
29+
registerSPIRVDialectTranslation(registry);
30+
context.appendDialectRegistry(registry);
31+
}

mlir/test/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ if(MLIR_ENABLE_ROCM_RUNNER)
142142
list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
143143
endif()
144144

145+
if(MLIR_ENABLE_SYCL_RUNNER)
146+
list(APPEND MLIR_TEST_DEPENDS mlir_sycl_runtime)
147+
endif()
148+
145149
if (MLIR_RUN_ARM_SME_TESTS AND NOT ARM_SME_ABI_ROUTINES_SHLIB)
146150
list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_abi_stubs)
147151
endif()
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(spirv-attach-target{ver=v1.0 caps=Addresses,Int64,Kernel},convert-gpu-to-spirv{use-64bit-index=true},gpu.module(spirv.module(spirv-lower-abi-attrs,spirv-update-vce)),func.func(llvm-request-c-wrappers),convert-scf-to-cf,convert-cf-to-llvm,convert-arith-to-llvm,convert-math-to-llvm,convert-func-to-llvm,gpu-to-llvm{use-bare-pointers-for-kernels=true},gpu-module-to-binary,expand-strided-metadata,lower-affine,finalize-memref-to-llvm,reconcile-unrealized-casts)' \
2+
// RUN: | mlir-cpu-runner \
3+
// RUN: --shared-libs=%mlir_sycl_runtime \
4+
// RUN: --shared-libs=%mlir_runner_utils \
5+
// RUN: --entry-point-result=void \
6+
// RUN: | FileCheck %s
7+
8+
module @add attributes {gpu.container_module} {
9+
memref.global "private" constant @__constant_2x2x2xf32_0 : memref<2x2x2xf32> = dense<[[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8 ]]]>
10+
memref.global "private" constant @__constant_2x2x2xf32 : memref<2x2x2xf32> = dense<[[[1.2, 2.3], [4.5, 5.8]], [[7.2, 8.3], [10.5, 11.8]]]>
11+
func.func @main() {
12+
%0 = memref.get_global @__constant_2x2x2xf32 : memref<2x2x2xf32>
13+
%1 = memref.get_global @__constant_2x2x2xf32_0 : memref<2x2x2xf32>
14+
%2 = call @test(%0, %1) : (memref<2x2x2xf32>, memref<2x2x2xf32>) -> memref<2x2x2xf32>
15+
%cast = memref.cast %2 : memref<2x2x2xf32> to memref<*xf32>
16+
call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
17+
return
18+
}
19+
func.func private @printMemrefF32(memref<*xf32>)
20+
func.func @test(%arg0: memref<2x2x2xf32>, %arg1: memref<2x2x2xf32>) -> memref<2x2x2xf32> {
21+
%c2 = arith.constant 2 : index
22+
%c1 = arith.constant 1 : index
23+
%mem = gpu.alloc host_shared () : memref<2x2x2xf32>
24+
memref.copy %arg1, %mem : memref<2x2x2xf32> to memref<2x2x2xf32>
25+
%memref_0 = gpu.alloc host_shared () : memref<2x2x2xf32>
26+
memref.copy %arg0, %memref_0 : memref<2x2x2xf32> to memref<2x2x2xf32>
27+
%memref_2 = gpu.alloc host_shared () : memref<2x2x2xf32>
28+
%2 = gpu.wait async
29+
%3 = gpu.launch_func async [%2] @test_kernel::@test_kernel blocks in (%c2, %c2, %c2) threads in (%c1, %c1, %c1) args(%memref_0 : memref<2x2x2xf32>, %mem : memref<2x2x2xf32>, %memref_2 : memref<2x2x2xf32>)
30+
gpu.wait [%3]
31+
%alloc = memref.alloc() : memref<2x2x2xf32>
32+
memref.copy %memref_2, %alloc : memref<2x2x2xf32> to memref<2x2x2xf32>
33+
%4 = gpu.wait async
34+
%5 = gpu.dealloc async [%4] %memref_2 : memref<2x2x2xf32>
35+
%6 = gpu.dealloc async [%5] %memref_0 : memref<2x2x2xf32>
36+
%7 = gpu.dealloc async [%6] %mem : memref<2x2x2xf32>
37+
gpu.wait [%7]
38+
return %alloc : memref<2x2x2xf32>
39+
}
40+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int64, Kernel], []>, api=OpenCL, #spirv.resource_limits<>>} {
41+
gpu.func @test_kernel(%arg0: memref<2x2x2xf32>, %arg1: memref<2x2x2xf32>, %arg2: memref<2x2x2xf32>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 2, 2, 2>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
42+
%0 = gpu.block_id x
43+
%1 = gpu.block_id y
44+
%2 = gpu.block_id z
45+
%3 = memref.load %arg0[%0, %1, %2] : memref<2x2x2xf32>
46+
%4 = memref.load %arg1[%0, %1, %2] : memref<2x2x2xf32>
47+
%5 = arith.addf %3, %4 : f32
48+
memref.store %5, %arg2[%0, %1, %2] : memref<2x2x2xf32>
49+
gpu.return
50+
}
51+
}
52+
// CHECK: [2.3, 4.5]
53+
// CHECK: [7.8, 10.2]
54+
// CHECK: [12.7, 14.9]
55+
// CHECK: [18.2, 20.6]
56+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(spirv-attach-target{ver=v1.0 caps=Addresses,Int64,Kernel},convert-gpu-to-spirv{use-64bit-index=true},gpu.module(spirv.module(spirv-lower-abi-attrs,spirv-update-vce)),func.func(llvm-request-c-wrappers),convert-scf-to-cf,convert-cf-to-llvm,convert-arith-to-llvm,convert-math-to-llvm,convert-func-to-llvm,gpu-to-llvm{use-bare-pointers-for-kernels=true},gpu-module-to-binary,expand-strided-metadata,lower-affine,finalize-memref-to-llvm,reconcile-unrealized-casts)' \
2+
// RUN: | mlir-cpu-runner \
3+
// RUN: --shared-libs=%mlir_sycl_runtime \
4+
// RUN: --shared-libs=%mlir_runner_utils \
5+
// RUN: --entry-point-result=void \
6+
// RUN: | FileCheck %s
7+
8+
module @add attributes {gpu.container_module} {
9+
memref.global "private" constant @__constant_3x3xi64_0 : memref<3x3xi64> = dense<[[1, 4098, 3], [16777220, 5, 4294967302], [7, 1099511627784, 9]]>
10+
memref.global "private" constant @__constant_3x3xi64 : memref<3x3xi64> = dense<[[1, 2, 3], [4, 5, 4102], [16777223, 4294967304, 1099511627785]]>
11+
func.func @main() {
12+
%0 = memref.get_global @__constant_3x3xi64 : memref<3x3xi64>
13+
%1 = memref.get_global @__constant_3x3xi64_0 : memref<3x3xi64>
14+
%2 = call @test(%0, %1) : (memref<3x3xi64>, memref<3x3xi64>) -> memref<3x3xi64>
15+
%cast = memref.cast %2 : memref<3x3xi64> to memref<*xi64>
16+
call @printMemrefI64(%cast) : (memref<*xi64>) -> ()
17+
return
18+
}
19+
func.func private @printMemrefI64(memref<*xi64>)
20+
func.func @test(%arg0: memref<3x3xi64>, %arg1: memref<3x3xi64>) -> memref<3x3xi64> {
21+
%c3 = arith.constant 3 : index
22+
%c1 = arith.constant 1 : index
23+
%mem = gpu.alloc host_shared () : memref<3x3xi64>
24+
memref.copy %arg1, %mem : memref<3x3xi64> to memref<3x3xi64>
25+
%memref_0 = gpu.alloc host_shared () : memref<3x3xi64>
26+
memref.copy %arg0, %memref_0 : memref<3x3xi64> to memref<3x3xi64>
27+
%memref_2 = gpu.alloc host_shared () : memref<3x3xi64>
28+
%2 = gpu.wait async
29+
%3 = gpu.launch_func async [%2] @test_kernel::@test_kernel blocks in (%c3, %c3, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<3x3xi64>, %mem : memref<3x3xi64>, %memref_2 : memref<3x3xi64>)
30+
gpu.wait [%3]
31+
%alloc = memref.alloc() : memref<3x3xi64>
32+
memref.copy %memref_2, %alloc : memref<3x3xi64> to memref<3x3xi64>
33+
%4 = gpu.wait async
34+
%5 = gpu.dealloc async [%4] %memref_2 : memref<3x3xi64>
35+
%6 = gpu.dealloc async [%5] %memref_0 : memref<3x3xi64>
36+
%7 = gpu.dealloc async [%6] %mem : memref<3x3xi64>
37+
gpu.wait [%7]
38+
return %alloc : memref<3x3xi64>
39+
}
40+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int64, Kernel], []>, api=OpenCL, #spirv.resource_limits<>>} {
41+
gpu.func @test_kernel(%arg0: memref<3x3xi64>, %arg1: memref<3x3xi64>, %arg2: memref<3x3xi64>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 3, 3, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
42+
%0 = gpu.block_id x
43+
%1 = gpu.block_id y
44+
%2 = memref.load %arg0[%0, %1] : memref<3x3xi64>
45+
%3 = memref.load %arg1[%0, %1] : memref<3x3xi64>
46+
%4 = arith.addi %2, %3 : i64
47+
memref.store %4, %arg2[%0, %1] : memref<3x3xi64>
48+
gpu.return
49+
}
50+
}
51+
// CHECK: [2, 4100, 6],
52+
// CHECK: [16777224, 10, 4294971404],
53+
// CHECK: [16777230, 1103806595088, 1099511627794]
54+
}

0 commit comments

Comments
 (0)