Skip to content

Commit 812d1fa

Browse files
committed
[mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op
This change contains following: - adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass. - Fixes Constant decoration parsing for spirv GlobalVariable. - minor modification to spirv.CL.printf op assembly format.
1 parent 00b6d03 commit 812d1fa

File tree

6 files changed

+190
-5
lines changed

6 files changed

+190
-5
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
875875
#### Example:
876876

877877
```mlir
878-
%0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
878+
%0 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
879879
```
880880
}];
881881

@@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
889889
);
890890

891891
let assemblyFormat = [{
892-
$format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
892+
$format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
893893
}];
894894

895895
let hasVerifier = 0;

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
135135
ConversionPatternRewriter &rewriter) const override;
136136
};
137137

138+
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
139+
public:
140+
using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
141+
142+
LogicalResult
143+
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
144+
ConversionPatternRewriter &rewriter) const override;
145+
};
146+
138147
} // namespace
139148

140149
//===----------------------------------------------------------------------===//
@@ -607,6 +616,108 @@ class GPUSubgroupReduceConversion final
607616
}
608617
};
609618

619+
LogicalResult GPUPrintfConversion::matchAndRewrite(
620+
gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
621+
ConversionPatternRewriter &rewriter) const {
622+
623+
auto loc = gpuPrintfOp.getLoc();
624+
625+
auto funcOp =
626+
rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
627+
628+
auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
629+
630+
const char formatStringPrefix[] = "printfMsg";
631+
unsigned stringNumber = 0;
632+
mlir::SmallString<16> globalVarName;
633+
mlir::spirv::GlobalVariableOp globalVar;
634+
635+
// formulate spirv global variable name
636+
do {
637+
globalVarName.clear();
638+
(formatStringPrefix + llvm::Twine(stringNumber++))
639+
.toStringRef(globalVarName);
640+
} while (moduleOp.lookupSymbol(globalVarName));
641+
642+
auto i8Type = rewriter.getI8Type();
643+
auto i32Type = rewriter.getI32Type();
644+
645+
unsigned scNum = 0;
646+
auto createSpecConstant = [&](unsigned value) {
647+
auto attr = rewriter.getI8IntegerAttr(value);
648+
mlir::SmallString<16> specCstName;
649+
(llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
650+
.toStringRef(specCstName);
651+
652+
return rewriter.create<mlir::spirv::SpecConstantOp>(
653+
loc, rewriter.getStringAttr(specCstName), attr);
654+
};
655+
656+
// define GlobalVarOp with printf format string using SpecConstants
657+
// and make composite of SpecConstants
658+
{
659+
mlir::Operation *parent =
660+
mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
661+
662+
mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
663+
664+
mlir::Block &entryBlock = *parent->getRegion(0).begin();
665+
rewriter.setInsertionPointToStart(
666+
&entryBlock); // insertion point at module level
667+
668+
// Create Constituents with SpecConstant to construct
669+
// SpecConstantCompositeOp
670+
llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
671+
formatString.push_back('\0'); // Null terminate for C
672+
mlir::SmallVector<mlir::Attribute, 4> constituents;
673+
for (auto c : formatString) {
674+
auto cSpecConstantOp = createSpecConstant(c);
675+
constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
676+
}
677+
678+
// Create specialization constant composite defined via spirv.SpecConstant
679+
size_t contentSize = constituents.size();
680+
auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
681+
mlir::spirv::SpecConstantCompositeOp specCstComposite;
682+
mlir::SmallString<16> specCstCompositeName;
683+
(llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
684+
specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
685+
loc, mlir::TypeAttr::get(globalType),
686+
rewriter.getStringAttr(specCstCompositeName),
687+
rewriter.getArrayAttr(constituents));
688+
689+
// Define GlobalVariable initialized from Constant Composite
690+
globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
691+
loc,
692+
mlir::spirv::PointerType::get(
693+
globalType, mlir::spirv::StorageClass::UniformConstant),
694+
globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
695+
696+
globalVar->setAttr("Constant", rewriter.getUnitAttr());
697+
}
698+
699+
// Get SSA value of Global variable
700+
mlir::Value globalPtr =
701+
rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
702+
mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
703+
loc,
704+
mlir::spirv::PointerType::get(i8Type,
705+
mlir::spirv::StorageClass::UniformConstant),
706+
globalPtr);
707+
708+
// Get printf arguments
709+
auto argsRange = adaptor.getArgs();
710+
mlir::SmallVector<mlir::Value, 4> printfArgs;
711+
printfArgs.reserve(argsRange.size() + 1);
712+
printfArgs.append(argsRange.begin(), argsRange.end());
713+
714+
rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
715+
716+
rewriter.eraseOp(gpuPrintfOp);
717+
718+
return mlir::success();
719+
}
720+
610721
//===----------------------------------------------------------------------===//
611722
// GPU To SPIRV Patterns.
612723
//===----------------------------------------------------------------------===//
@@ -630,5 +741,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
630741
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
631742
spirv::BuiltIn::SubgroupSize>,
632743
WorkGroupSizeConversion, GPUAllReduceConversion,
633-
GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
744+
GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
745+
patterns.getContext());
634746
}

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
309309
case spirv::Decoration::RelaxedPrecision:
310310
case spirv::Decoration::Restrict:
311311
case spirv::Decoration::RestrictPointer:
312+
case spirv::Decoration::Constant:
312313
if (words.size() != 2) {
313314
return emitError(unknownLoc, "OpDecoration with ")
314315
<< decorationName << "needs a single target <id>";

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
272272
case spirv::Decoration::RelaxedPrecision:
273273
case spirv::Decoration::Restrict:
274274
case spirv::Decoration::RestrictPointer:
275+
case spirv::Decoration::Constant:
275276
// For unit attributes and decoration attributes, the args list
276277
// has no values so we do nothing.
277278
if (isa<UnitAttr, DecorationAttr>(attr))
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
2+
3+
module attributes {
4+
gpu.container_module,
5+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
6+
} {
7+
func.func @main() {
8+
%c1 = arith.constant 1 : index
9+
10+
gpu.launch_func @kernels::@printf
11+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
12+
args()
13+
return
14+
}
15+
16+
gpu.module @kernels {
17+
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
18+
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
19+
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
20+
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
21+
gpu.func @printf() kernel
22+
attributes
23+
{spirv.entry_point_abi = #spirv.entry_point_abi<>} {
24+
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
25+
// CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
26+
// CHECK-NEXT {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : (!spirv.ptr<i8, UniformConstant>) -> i32
27+
gpu.printf "\nHello\n"
28+
// CHECK: spirv.Return
29+
gpu.return
30+
}
31+
}
32+
}
33+
34+
// -----
35+
36+
module attributes {
37+
gpu.container_module,
38+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
39+
} {
40+
func.func @main() {
41+
%c1 = arith.constant 1 : index
42+
%c100 = arith.constant 100: i32
43+
%cst_f32 = arith.constant 314.4: f32
44+
45+
gpu.launch_func @kernels1::@printf_args
46+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
47+
args(%c100: i32, %cst_f32: f32)
48+
return
49+
}
50+
51+
gpu.module @kernels1 {
52+
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
53+
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
54+
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
55+
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
56+
gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
57+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
58+
%0 = gpu.block_id x
59+
%1 = gpu.block_id y
60+
%2 = gpu.thread_id x
61+
62+
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
63+
// CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
64+
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]], {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, f32, i32)) -> i32
65+
gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
66+
67+
// CHECK: spirv.Return
68+
gpu.return
69+
}
70+
}
71+
}

mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
275275
//===----------------------------------------------------------------------===//
276276
// CHECK-LABEL: func.func @printf(
277277
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
278-
// CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
279-
%0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
278+
// CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
279+
%0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
280280
return %0 : i32
281281
}
282282

0 commit comments

Comments
 (0)