Skip to content

[mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op #78510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
#### Example:

```mlir
%0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
%0 = spirv.CL.printf %fmt %1, %2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
```
}];

Expand All @@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
);

let assemblyFormat = [{
$format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
$format ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
}];

let hasVerifier = 0;
Expand Down
130 changes: 129 additions & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final
}
};

// Formulate a unique variable/constant name after
// searching in the module for existing variable/constant names.
// This is to avoid name collision with existing variables.
// Example: printfMsg0, printfMsg1, printfMsg2, ...
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
std::string name;
unsigned number = 0;

do {
name.clear();
name = (prefix + llvm::Twine(number++)).str();
} while (moduleOp.lookupSymbol(name));

return name;
}

/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.

LogicalResult GPUPrintfConversion::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

Location loc = gpuPrintfOp.getLoc();

auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
if (!moduleOp)
return failure();

// SPIR-V global variable is used to initialize printf
// format string value, if there are multiple printf messages,
// each global var needs to be created with a unique name.
std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
spirv::GlobalVariableOp globalVar;

IntegerType i8Type = rewriter.getI8Type();
IntegerType i32Type = rewriter.getI32Type();

// Each character of printf format string is
// stored as a spec constant. We need to create
// unique name for this spec constant like
// @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
// for existing spec constant names.
auto createSpecConstant = [&](unsigned value) {
auto attr = rewriter.getI8IntegerAttr(value);
std::string specCstName =
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");

return rewriter.create<spirv::SpecConstantOp>(
loc, rewriter.getStringAttr(specCstName), attr);
};
{
Operation *parent =
SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());

ConversionPatternRewriter::InsertionGuard guard(rewriter);

Block &entryBlock = *parent->getRegion(0).begin();
rewriter.setInsertionPointToStart(
&entryBlock); // insertion point at module level

// Create Constituents with SpecConstant by scanning format string
// Each character of format string is stored as a spec constant
// and then these spec constants are used to create a
// SpecConstantCompositeOp.
llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C.
SmallVector<Attribute, 4> constituents;
for (char c : formatString) {
spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
}

// Create SpecConstantCompositeOp to initialize the global variable
size_t contentSize = constituents.size();
auto globalType = spirv::ArrayType::get(i8Type, contentSize);
spirv::SpecConstantCompositeOp specCstComposite;
// There will be one SpecConstantCompositeOp per printf message/global var,
// so no need do lookup for existing ones.
std::string specCstCompositeName =
(llvm::Twine(globalVarName) + "_scc").str();

specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
loc, TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));

auto ptrType = spirv::PointerType::get(
globalType, spirv::StorageClass::UniformConstant);

// Define a GlobalVarOp initialized using specialized constants
// that is used to specify the printf format string
// to be passed to the SPIRV CLPrintfOp.
globalVar = rewriter.create<spirv::GlobalVariableOp>(
loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));

globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
// Get SSA value of Global variable and create pointer to i8 to point to
// the format string.
Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
Value fmtStr = rewriter.create<spirv::BitcastOp>(
loc,
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
globalPtr);

// Get printf arguments.
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());

rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);

// Need to erase the gpu.printf op as gpu.printf does not use result vs
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
// printf op.
rewriter.eraseOp(gpuPrintfOp);

return success();
}

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
Expand All @@ -620,5 +747,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
WorkGroupSizeConversion, GPUAllReduceConversion,
GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
patterns.getContext());
}
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/printf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s | FileCheck %s

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
} {
func.func @main() {
%c1 = arith.constant 1 : index

gpu.launch_func @kernels::@printf
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args()
return
}

gpu.module @kernels {
// CHECK: spirv.module @{{.*}} Physical32 OpenCL
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
gpu.func @printf() kernel
attributes
{spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : !spirv.ptr<i8, UniformConstant> -> i32
gpu.printf "\nHello\n"
// CHECK: spirv.Return
gpu.return
}
}
}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
} {
func.func @main() {
%c1 = arith.constant 1 : index
%c100 = arith.constant 100: i32
%cst_f32 = arith.constant 314.4: f32

gpu.launch_func @kernels1::@printf_args
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args(%c100: i32, %cst_f32: f32)
return
}

gpu.module @kernels1 {
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
%1 = gpu.block_id y
%2 = gpu.thread_id x

// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, f32, i32 -> i32
gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index

// CHECK: spirv.Return
gpu.return
}
}
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
// spirv.CL.printf
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @printf(
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
// CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
%0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
func.func @printf(%fmt : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
// CHECK: spirv.CL.printf {{%.*}} {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
%0 = spirv.CL.printf %fmt %arg1, %arg2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
return %0 : i32
}

Expand Down
Loading