Skip to content

Commit 2dace04

Browse files
authored
[mlir][spirv] Implement gpu::TargetAttrInterface (#69949)
This commit implements gpu::TargetAttrInterface for SPIR-V target attribute. The plan is to use this to enable GPU compilation pipeline for OpenCL kernels later. The changes do not impact Vulkan shaders using milr-vulkan-runner. New GPU Dialect transform pass spirv-attach-target is implemented for attaching attribute from CLI. gpu-module-to-binary pass now works with GPU module that has SPIR-V module with OpenCL kernel functions inside.
1 parent a54545b commit 2dace04

File tree

14 files changed

+355
-0
lines changed

14 files changed

+355
-0
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "Utils.h"
1717
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include <optional>
2021

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,48 @@ def GpuROCDLAttachTarget: Pass<"rocdl-attach-target", ""> {
188188
];
189189
}
190190

191+
def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
192+
let summary = "Attaches an SPIR-V target attribute to a GPU Module.";
193+
let description = [{
194+
This pass searches for all GPU Modules in the immediate regions and attaches
195+
an SPIR-V target if the module matches the name specified by the `module` argument.
196+
197+
Example:
198+
```
199+
// Given the following file: in1.mlir:
200+
gpu.module @nvvm_module_1 {...}
201+
gpu.module @spirv_module_1 {...}
202+
// With
203+
// mlir-opt --spirv-attach-target="module=spirv.* ver=v1.0 caps=Kernel" in1.mlir
204+
// it will generate,
205+
gpu.module @nvvm_module_1 {...}
206+
gpu.module @spirv_module_1 [#spirv.target<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>] {...}
207+
```
208+
}];
209+
let options = [
210+
Option<"moduleMatcher", "module", "std::string",
211+
/*default=*/ [{""}],
212+
"Regex used to identify the modules to attach the target to.">,
213+
Option<"spirvVersion", "ver", "std::string",
214+
/*default=*/ "\"v1.0\"",
215+
"SPIR-V Version.">,
216+
ListOption<"spirvCapabilities", "caps", "std::string",
217+
"List of supported SPIR-V Capabilities">,
218+
ListOption<"spirvExtensions", "exts", "std::string",
219+
"List of supported SPIR-V Extensions">,
220+
Option<"clientApi", "client_api", "std::string",
221+
/*default=*/ "\"Unknown\"",
222+
"Client API">,
223+
Option<"deviceVendor", "vendor", "std::string",
224+
/*default=*/ "\"Unknown\"",
225+
"Device Vendor">,
226+
Option<"deviceType", "device_type", "std::string",
227+
/*default=*/ "\"Unknown\"",
228+
"Device Type">,
229+
Option<"deviceId", "device_id", "uint32_t",
230+
/*default=*/ "mlir::spirv::TargetEnvAttr::kUnknownDeviceID",
231+
"Device ID">,
232+
];
233+
}
234+
191235
#endif // MLIR_DIALECT_GPU_PASSES

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717
#include "mlir/IR/BuiltinAttributes.h"
1818
#include "mlir/Support/LLVM.h"
1919

20+
namespace mlir {
21+
namespace spirv {
22+
class VerCapExtAttr;
23+
}
24+
} // namespace mlir
25+
2026
// Pull in TableGen'erated SPIR-V attribute definitions for target and ABI.
2127
#define GET_ATTRDEF_CLASSES
2228
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h.inc"

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
#include "mlir/Interfaces/CastInterfaces.h"
9292
#include "mlir/Target/LLVM/NVVM/Target.h"
9393
#include "mlir/Target/LLVM/ROCDL/Target.h"
94+
#include "mlir/Target/SPIRV/Target.h"
9495

9596
namespace mlir {
9697

@@ -175,6 +176,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
175176
vector::registerSubsetOpInterfaceExternalModels(registry);
176177
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
177178
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
179+
spirv::registerSPIRVTargetInterfaceExternalModels(registry);
178180
}
179181

180182
/// Append all the MLIR dialects to the registry contained in the given context.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- Target.h - MLIR SPIR-V target registration ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides registration calls for attaching the SPIR-V target interface.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TARGET_SPIRV_TARGET_H
14+
#define MLIR_TARGET_SPIRV_TARGET_H
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
class MLIRContext;
19+
namespace spirv {
20+
/// Registers the `TargetAttrInterface` for the `#spirv.target_env` attribute in
21+
/// the given registry.
22+
void registerSPIRVTargetInterfaceExternalModels(DialectRegistry &registry);
23+
24+
/// Registers the `TargetAttrInterface` for the `#spirv.target_env` attribute in
25+
/// the registry associated with the given context.
26+
void registerSPIRVTargetInterfaceExternalModels(MLIRContext &context);
27+
} // namespace spirv
28+
} // namespace mlir
29+
30+
#endif // MLIR_TARGET_SPIRV_TARGET_H

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
6060
Transforms/SerializeToCubin.cpp
6161
Transforms/SerializeToHsaco.cpp
6262
Transforms/ShuffleRewriter.cpp
63+
Transforms/SPIRVAttachTarget.cpp
6364
Transforms/ROCDLAttachTarget.cpp
6465

6566
ADDITIONAL_HEADER_DIRS
@@ -95,6 +96,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
9596
MLIRPass
9697
MLIRSCFDialect
9798
MLIRSideEffectInterfaces
99+
MLIRSPIRVTarget
98100
MLIRSupport
99101
MLIRROCDLTarget
100102
MLIRTransformUtils

mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1919
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2020
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
21+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2122
#include "mlir/IR/BuiltinOps.h"
2223
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2324

@@ -53,6 +54,7 @@ void GpuModuleToBinaryPass::getDependentDialects(
5354
#if MLIR_ROCM_CONVERSIONS_ENABLED == 1
5455
registry.insert<ROCDL::ROCDLDialect>();
5556
#endif
57+
registry.insert<spirv::SPIRVDialect>();
5658
}
5759

5860
void GpuModuleToBinaryPass::runOnOperation() {
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===- SPIRVAttachTarget.cpp - Attach an SPIR-V target --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the `GPUSPIRVAttachTarget` pass, attaching
10+
// `#spirv.target_env` attributes to GPU modules.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19+
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20+
#include "mlir/IR/Builders.h"
21+
#include "mlir/Pass/Pass.h"
22+
#include "mlir/Target/SPIRV/Target.h"
23+
#include "llvm/Support/Regex.h"
24+
25+
namespace mlir {
26+
#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
27+
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
28+
} // namespace mlir
29+
30+
using namespace mlir;
31+
using namespace mlir::spirv;
32+
33+
namespace {
34+
struct SPIRVAttachTarget
35+
: public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
36+
using Base::Base;
37+
38+
void runOnOperation() override;
39+
40+
void getDependentDialects(DialectRegistry &registry) const override {
41+
registry.insert<spirv::SPIRVDialect>();
42+
}
43+
};
44+
} // namespace
45+
46+
void SPIRVAttachTarget::runOnOperation() {
47+
OpBuilder builder(&getContext());
48+
auto versionSymbol = symbolizeVersion(spirvVersion);
49+
if (!versionSymbol)
50+
return signalPassFailure();
51+
auto apiSymbol = symbolizeClientAPI(clientApi);
52+
if (!apiSymbol)
53+
return signalPassFailure();
54+
auto vendorSymbol = symbolizeVendor(deviceVendor);
55+
if (!vendorSymbol)
56+
return signalPassFailure();
57+
auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
58+
if (!deviceTypeSymbol)
59+
return signalPassFailure();
60+
61+
Version version = versionSymbol.value();
62+
SmallVector<Capability, 4> capabilities;
63+
SmallVector<Extension, 8> extensions;
64+
for (auto cap : spirvCapabilities) {
65+
auto capSymbol = symbolizeCapability(cap);
66+
if (capSymbol)
67+
capabilities.push_back(capSymbol.value());
68+
}
69+
ArrayRef<Capability> caps(capabilities);
70+
for (auto ext : spirvExtensions) {
71+
auto extSymbol = symbolizeExtension(ext);
72+
if (extSymbol)
73+
extensions.push_back(extSymbol.value());
74+
}
75+
ArrayRef<Extension> exts(extensions);
76+
VerCapExtAttr vce = VerCapExtAttr::get(version, caps, exts, &getContext());
77+
auto target = TargetEnvAttr::get(vce, getDefaultResourceLimits(&getContext()),
78+
apiSymbol.value(), vendorSymbol.value(),
79+
deviceTypeSymbol.value(), deviceId);
80+
llvm::Regex matcher(moduleMatcher);
81+
getOperation()->walk([&](gpu::GPUModuleOp gpuModule) {
82+
// Check if the name of the module matches.
83+
if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
84+
return;
85+
// Create the target array.
86+
SmallVector<Attribute> targets;
87+
if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
88+
targets.append(attrs->getValue().begin(), attrs->getValue().end());
89+
targets.push_back(target);
90+
// Remove any duplicate targets.
91+
targets.erase(std::unique(targets.begin(), targets.end()), targets.end());
92+
// Update the target attribute array.
93+
gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
94+
});
95+
}

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "SPIRVParsingUtils.h"
1616

17+
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1718
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1819
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1920
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -133,6 +134,7 @@ void SPIRVDialect::initialize() {
133134

134135
// Allow unknown operations because SPIR-V is extensible.
135136
allowUnknownOperations();
137+
declarePromisedInterface<TargetEnvAttr, gpu::TargetAttrInterface>();
136138
}
137139

138140
std::string SPIRVDialect::getAttributeName(Decoration decoration) {

mlir/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(Serialization)
44
set(LLVM_OPTIONAL_SOURCES
55
SPIRVBinaryUtils.cpp
66
TranslateRegistration.cpp
7+
Target.cpp
78
)
89

910
add_mlir_translation_library(MLIRSPIRVBinaryUtils
@@ -26,3 +27,15 @@ add_mlir_translation_library(MLIRSPIRVTranslateRegistration
2627
MLIRSupport
2728
MLIRTranslateLib
2829
)
30+
31+
add_mlir_dialect_library(MLIRSPIRVTarget
32+
Target.cpp
33+
34+
LINK_LIBS PUBLIC
35+
MLIRIR
36+
MLIRSPIRVDialect
37+
MLIRSPIRVSerialization
38+
MLIRSPIRVDeserialization
39+
MLIRSupport
40+
MLIRTranslateLib
41+
)

mlir/lib/Target/SPIRV/Target.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//===- Target.cpp - MLIR SPIR-V target compilation --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This files defines SPIR-V target related functions including registration
10+
// calls for the `#spirv.target_env` compilation attribute.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Target/SPIRV/Target.h"
15+
16+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
20+
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
21+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
22+
#include "mlir/Target/LLVMIR/Export.h"
23+
#include "mlir/Target/SPIRV/Serialization.h"
24+
25+
#include "llvm/Support/FileSystem.h"
26+
#include "llvm/Support/FileUtilities.h"
27+
#include "llvm/Support/FormatVariadic.h"
28+
#include "llvm/Support/MemoryBuffer.h"
29+
#include "llvm/Support/Path.h"
30+
#include "llvm/Support/Process.h"
31+
#include "llvm/Support/Program.h"
32+
#include "llvm/Support/TargetSelect.h"
33+
34+
#include <cstdlib>
35+
#include <cstring>
36+
37+
using namespace mlir;
38+
using namespace mlir::spirv;
39+
40+
namespace {
41+
// SPIR-V implementation of the gpu:TargetAttrInterface.
42+
class SPIRVTargetAttrImpl
43+
: public gpu::TargetAttrInterface::FallbackModel<SPIRVTargetAttrImpl> {
44+
public:
45+
std::optional<SmallVector<char, 0>>
46+
serializeToObject(Attribute attribute, Operation *module,
47+
const gpu::TargetOptions &options) const;
48+
49+
Attribute createObject(Attribute attribute,
50+
const SmallVector<char, 0> &object,
51+
const gpu::TargetOptions &options) const;
52+
};
53+
} // namespace
54+
55+
// Register the SPIR-V dialect, the SPIR-V translation & the target interface.
56+
void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
57+
DialectRegistry &registry) {
58+
registry.addExtension(+[](MLIRContext *ctx, spirv::SPIRVDialect *dialect) {
59+
spirv::TargetEnvAttr::attachInterface<SPIRVTargetAttrImpl>(*ctx);
60+
});
61+
}
62+
63+
void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
64+
MLIRContext &context) {
65+
DialectRegistry registry;
66+
registerSPIRVTargetInterfaceExternalModels(registry);
67+
context.appendDialectRegistry(registry);
68+
}
69+
70+
// Reuse from existing serializer
71+
std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
72+
Attribute attribute, Operation *module,
73+
const gpu::TargetOptions &options) const {
74+
if (!module)
75+
return std::nullopt;
76+
auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
77+
if (!gpuMod) {
78+
module->emitError("expected to be a gpu.module op");
79+
return std::nullopt;
80+
}
81+
auto spvMods = gpuMod.getOps<spirv::ModuleOp>();
82+
if (spvMods.empty())
83+
return std::nullopt;
84+
85+
auto spvMod = *spvMods.begin();
86+
llvm::SmallVector<uint32_t, 0> spvBinary;
87+
88+
spvBinary.clear();
89+
// Serialize the spirv.module op to SPIR-V blob.
90+
if (mlir::failed(spirv::serialize(spvMod, spvBinary))) {
91+
spvMod.emitError() << "failed to serialize SPIR-V module";
92+
return std::nullopt;
93+
}
94+
95+
SmallVector<char, 0> spvData(spvBinary.size() * sizeof(uint32_t), 0);
96+
std::memcpy(spvData.data(), spvBinary.data(), spvData.size());
97+
98+
spvMod.erase();
99+
return spvData;
100+
}
101+
102+
// Prepare Attribute for gpu.binary with serialized kernel object
103+
Attribute
104+
SPIRVTargetAttrImpl::createObject(Attribute attribute,
105+
const SmallVector<char, 0> &object,
106+
const gpu::TargetOptions &options) const {
107+
gpu::CompilationTarget format = options.getCompilationTarget();
108+
DictionaryAttr objectProps;
109+
Builder builder(attribute.getContext());
110+
return builder.getAttr<gpu::ObjectAttr>(
111+
attribute, format,
112+
builder.getStringAttr(StringRef(object.data(), object.size())),
113+
objectProps);
114+
}

0 commit comments

Comments
 (0)