Skip to content

Commit b0e9c5f

Browse files
committed
address feedback
1 parent 7b1f78e commit b0e9c5f

File tree

3 files changed

+66
-52
lines changed

3 files changed

+66
-52
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,9 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
875875
#### Example:
876876

877877
```mlir
878-
%0 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
878+
%0 = spirv.CL.printf %fmt %1, %2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
879+
880+
$format `,` ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
879881
```
880882
}];
881883

@@ -889,7 +891,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
889891
);
890892

891893
let assemblyFormat = [{
892-
$format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
894+
$format ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
893895
}];
894896

895897
let hasVerifier = 0;

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
123123

124124
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
125125
public:
126-
using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
126+
using OpConversionPattern::OpConversionPattern;
127127

128128
LogicalResult
129129
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
@@ -606,106 +606,118 @@ class GPUSubgroupReduceConversion final
606606
}
607607
};
608608

609+
/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
610+
609611
LogicalResult GPUPrintfConversion::matchAndRewrite(
610612
gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
611613
ConversionPatternRewriter &rewriter) const {
612614

613-
auto loc = gpuPrintfOp.getLoc();
615+
Location loc = gpuPrintfOp.getLoc();
614616

615-
auto funcOp =
616-
rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
617+
auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
617618

618-
auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
619+
if (!moduleOp) {
620+
return success();
621+
}
619622

620623
const char formatStringPrefix[] = "printfMsg";
621624
unsigned stringNumber = 0;
622-
mlir::SmallString<16> globalVarName;
623-
mlir::spirv::GlobalVariableOp globalVar;
624-
625-
// formulate spirv global variable name
625+
SmallString<16> globalVarName;
626+
spirv::GlobalVariableOp globalVar;
627+
628+
// SPIR-V global variable is used to initialize printf
629+
// format string value, if there are multiple printf messages,
630+
// each global var needs to be created with a unique name.
631+
// like printfMsg0, printfMsg1, ...
632+
// Formulate unique global variable name after
633+
// searching in the module for existing global variable names.
634+
// This is to avoid name collision with existing global variables.
626635
do {
627636
globalVarName.clear();
628637
(formatStringPrefix + llvm::Twine(stringNumber++))
629638
.toStringRef(globalVarName);
630639
} while (moduleOp.lookupSymbol(globalVarName));
631640

632-
auto i8Type = rewriter.getI8Type();
633-
auto i32Type = rewriter.getI32Type();
641+
IntegerType i8Type = rewriter.getI8Type();
642+
IntegerType i32Type = rewriter.getI32Type();
634643

635-
unsigned scNum = 0;
644+
// Each character of printf format string is
645+
// stored as a spec constant. We need to create
646+
// unique name for this spec constant like
647+
// @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
648+
// for existing spec constant names.
649+
unsigned specConstantNum = 0;
636650
auto createSpecConstant = [&](unsigned value) {
637651
auto attr = rewriter.getI8IntegerAttr(value);
638-
mlir::SmallString<16> specCstName;
639-
(llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
652+
SmallString<16> specCstName;
653+
(llvm::Twine(globalVarName) + "_sc" + llvm::Twine(specConstantNum++))
640654
.toStringRef(specCstName);
641655

642-
return rewriter.create<mlir::spirv::SpecConstantOp>(
656+
return rewriter.create<spirv::SpecConstantOp>(
643657
loc, rewriter.getStringAttr(specCstName), attr);
644658
};
645-
646-
// define GlobalVarOp with printf format string using SpecConstants
647-
// and make composite of SpecConstants
648659
{
649-
mlir::Operation *parent =
650-
mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
660+
Operation *parent =
661+
SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
651662

652-
mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
663+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
653664

654-
mlir::Block &entryBlock = *parent->getRegion(0).begin();
665+
Block &entryBlock = *parent->getRegion(0).begin();
655666
rewriter.setInsertionPointToStart(
656667
&entryBlock); // insertion point at module level
657668

658-
// Create Constituents with SpecConstant to construct
669+
// Create Constituents with SpecConstant by scanning format string
670+
// Each character of format string is stored as a spec constant
671+
// and then these spec constants are used to create a
659672
// SpecConstantCompositeOp
660-
llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
673+
llvm::SmallString<20> formatString(adaptor.getFormat());
661674
formatString.push_back('\0'); // Null terminate for C
662-
mlir::SmallVector<mlir::Attribute, 4> constituents;
675+
SmallVector<Attribute, 4> constituents;
663676
for (auto c : formatString) {
664677
auto cSpecConstantOp = createSpecConstant(c);
665-
constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
678+
constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
666679
}
667680

668-
// Create specialization constant composite defined via spirv.SpecConstant
681+
// Create SpecConstantCompositeOp to initialize the global variable
669682
size_t contentSize = constituents.size();
670-
auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
671-
mlir::spirv::SpecConstantCompositeOp specCstComposite;
672-
mlir::SmallString<16> specCstCompositeName;
683+
auto globalType = spirv::ArrayType::get(i8Type, contentSize);
684+
spirv::SpecConstantCompositeOp specCstComposite;
685+
SmallString<16> specCstCompositeName;
673686
(llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
674-
specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
675-
loc, mlir::TypeAttr::get(globalType),
687+
specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
688+
loc, TypeAttr::get(globalType),
676689
rewriter.getStringAttr(specCstCompositeName),
677690
rewriter.getArrayAttr(constituents));
678691

679-
// Define GlobalVariable initialized from Constant Composite
680-
globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
681-
loc,
682-
mlir::spirv::PointerType::get(
683-
globalType, mlir::spirv::StorageClass::UniformConstant),
684-
globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
692+
auto ptrType = spirv::PointerType::get(
693+
globalType, spirv::StorageClass::UniformConstant);
694+
695+
// Define a GlobalVarOp initialized using specialized constants
696+
// that is used to specify the printf format string
697+
// to be passed to the SPIRV CLPrintfOp.
698+
globalVar = rewriter.create<spirv::GlobalVariableOp>(
699+
loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
685700

686701
globalVar->setAttr("Constant", rewriter.getUnitAttr());
687702
}
688-
689703
// Get SSA value of Global variable
690-
mlir::Value globalPtr =
691-
rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
692-
mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
704+
Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
705+
Value fmtStr = rewriter.create<spirv::BitcastOp>(
693706
loc,
694-
mlir::spirv::PointerType::get(i8Type,
695-
mlir::spirv::StorageClass::UniformConstant),
707+
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
696708
globalPtr);
697709

698710
// Get printf arguments
699711
auto argsRange = adaptor.getArgs();
700-
mlir::SmallVector<mlir::Value, 4> printfArgs;
712+
SmallVector<Value, 4> printfArgs;
701713
printfArgs.reserve(argsRange.size() + 1);
702714
printfArgs.append(argsRange.begin(), argsRange.end());
703715

704-
rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
716+
rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
705717

706718
rewriter.eraseOp(gpuPrintfOp);
707719

708-
return mlir::success();
720+
return success();
709721
}
710722

711723
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
274274
// spirv.CL.printf
275275
//===----------------------------------------------------------------------===//
276276
// CHECK-LABEL: func.func @printf(
277-
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
278-
// CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
279-
%0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
277+
func.func @printf(%fmt : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
278+
// CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
279+
%0 = spirv.CL.printf %fmt, %arg1, %arg2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
280280
return %0 : i32
281281
}
282282

0 commit comments

Comments
 (0)