@@ -123,7 +123,7 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
123
123
124
124
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
125
125
public:
126
- using OpConversionPattern<gpu::PrintfOp> ::OpConversionPattern;
126
+ using OpConversionPattern::OpConversionPattern;
127
127
128
128
LogicalResult
129
129
matchAndRewrite (gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
@@ -606,106 +606,118 @@ class GPUSubgroupReduceConversion final
606
606
}
607
607
};
608
608
609
+ // / Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
610
+
609
611
LogicalResult GPUPrintfConversion::matchAndRewrite (
610
612
gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
611
613
ConversionPatternRewriter &rewriter) const {
612
614
613
- auto loc = gpuPrintfOp.getLoc ();
615
+ Location loc = gpuPrintfOp.getLoc ();
614
616
615
- auto funcOp =
616
- rewriter.getBlock ()->getParent ()->getParentOfType <mlir::spirv::FuncOp>();
617
+ auto moduleOp = gpuPrintfOp->getParentOfType <spirv::ModuleOp>();
617
618
618
- auto moduleOp = funcOp->getParentOfType <mlir::spirv::ModuleOp>();
619
+ if (!moduleOp) {
620
+ return success ();
621
+ }
619
622
620
623
const char formatStringPrefix[] = " printfMsg" ;
621
624
unsigned stringNumber = 0 ;
622
- mlir::SmallString<16 > globalVarName;
623
- mlir::spirv::GlobalVariableOp globalVar;
624
-
625
- // formulate spirv global variable name
625
+ SmallString<16 > globalVarName;
626
+ spirv::GlobalVariableOp globalVar;
627
+
628
+ // SPIR-V global variable is used to initialize printf
629
+ // format string value, if there are multiple printf messages,
630
+ // each global var needs to be created with a unique name.
631
+ // like printfMsg0, printfMsg1, ...
632
+ // Formulate unique global variable name after
633
+ // searching in the module for existing global variable names.
634
+ // This is to avoid name collision with existing global variables.
626
635
do {
627
636
globalVarName.clear ();
628
637
(formatStringPrefix + llvm::Twine (stringNumber++))
629
638
.toStringRef (globalVarName);
630
639
} while (moduleOp.lookupSymbol (globalVarName));
631
640
632
- auto i8Type = rewriter.getI8Type ();
633
- auto i32Type = rewriter.getI32Type ();
641
+ IntegerType i8Type = rewriter.getI8Type ();
642
+ IntegerType i32Type = rewriter.getI32Type ();
634
643
635
- unsigned scNum = 0 ;
644
+ // Each character of printf format string is
645
+ // stored as a spec constant. We need to create
646
+ // unique name for this spec constant like
647
+ // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
648
+ // for existing spec constant names.
649
+ unsigned specConstantNum = 0 ;
636
650
auto createSpecConstant = [&](unsigned value) {
637
651
auto attr = rewriter.getI8IntegerAttr (value);
638
- mlir:: SmallString<16 > specCstName;
639
- (llvm::Twine (globalVarName) + " _sc" + llvm::Twine (scNum ++))
652
+ SmallString<16 > specCstName;
653
+ (llvm::Twine (globalVarName) + " _sc" + llvm::Twine (specConstantNum ++))
640
654
.toStringRef (specCstName);
641
655
642
- return rewriter.create <mlir:: spirv::SpecConstantOp>(
656
+ return rewriter.create <spirv::SpecConstantOp>(
643
657
loc, rewriter.getStringAttr (specCstName), attr);
644
658
};
645
-
646
- // define GlobalVarOp with printf format string using SpecConstants
647
- // and make composite of SpecConstants
648
659
{
649
- mlir:: Operation *parent =
650
- mlir:: SymbolTable::getNearestSymbolTable (gpuPrintfOp->getParentOp ());
660
+ Operation *parent =
661
+ SymbolTable::getNearestSymbolTable (gpuPrintfOp->getParentOp ());
651
662
652
- mlir:: ConversionPatternRewriter::InsertionGuard guard (rewriter);
663
+ ConversionPatternRewriter::InsertionGuard guard (rewriter);
653
664
654
- mlir:: Block &entryBlock = *parent->getRegion (0 ).begin ();
665
+ Block &entryBlock = *parent->getRegion (0 ).begin ();
655
666
rewriter.setInsertionPointToStart (
656
667
&entryBlock); // insertion point at module level
657
668
658
- // Create Constituents with SpecConstant to construct
669
+ // Create Constituents with SpecConstant by scanning format string
670
+ // Each character of format string is stored as a spec constant
671
+ // and then these spec constants are used to create a
659
672
// SpecConstantCompositeOp
660
- llvm::SmallString<20 > formatString (gpuPrintfOp .getFormat ());
673
+ llvm::SmallString<20 > formatString (adaptor .getFormat ());
661
674
formatString.push_back (' \0 ' ); // Null terminate for C
662
- mlir:: SmallVector<mlir:: Attribute, 4 > constituents;
675
+ SmallVector<Attribute, 4 > constituents;
663
676
for (auto c : formatString) {
664
677
auto cSpecConstantOp = createSpecConstant (c);
665
- constituents.push_back (mlir:: SymbolRefAttr::get (cSpecConstantOp));
678
+ constituents.push_back (SymbolRefAttr::get (cSpecConstantOp));
666
679
}
667
680
668
- // Create specialization constant composite defined via spirv.SpecConstant
681
+ // Create SpecConstantCompositeOp to initialize the global variable
669
682
size_t contentSize = constituents.size ();
670
- auto globalType = mlir:: spirv::ArrayType::get (i8Type, contentSize);
671
- mlir:: spirv::SpecConstantCompositeOp specCstComposite;
672
- mlir:: SmallString<16 > specCstCompositeName;
683
+ auto globalType = spirv::ArrayType::get (i8Type, contentSize);
684
+ spirv::SpecConstantCompositeOp specCstComposite;
685
+ SmallString<16 > specCstCompositeName;
673
686
(llvm::Twine (globalVarName) + " _scc" ).toStringRef (specCstCompositeName);
674
- specCstComposite = rewriter.create <mlir:: spirv::SpecConstantCompositeOp>(
675
- loc, mlir:: TypeAttr::get (globalType),
687
+ specCstComposite = rewriter.create <spirv::SpecConstantCompositeOp>(
688
+ loc, TypeAttr::get (globalType),
676
689
rewriter.getStringAttr (specCstCompositeName),
677
690
rewriter.getArrayAttr (constituents));
678
691
679
- // Define GlobalVariable initialized from Constant Composite
680
- globalVar = rewriter.create <mlir::spirv::GlobalVariableOp>(
681
- loc,
682
- mlir::spirv::PointerType::get (
683
- globalType, mlir::spirv::StorageClass::UniformConstant),
684
- globalVarName, mlir::FlatSymbolRefAttr::get (specCstComposite));
692
+ auto ptrType = spirv::PointerType::get (
693
+ globalType, spirv::StorageClass::UniformConstant);
694
+
695
+ // Define a GlobalVarOp initialized using specialized constants
696
+ // that is used to specify the printf format string
697
+ // to be passed to the SPIRV CLPrintfOp.
698
+ globalVar = rewriter.create <spirv::GlobalVariableOp>(
699
+ loc, ptrType, globalVarName, FlatSymbolRefAttr::get (specCstComposite));
685
700
686
701
globalVar->setAttr (" Constant" , rewriter.getUnitAttr ());
687
702
}
688
-
689
703
// Get SSA value of Global variable
690
- mlir::Value globalPtr =
691
- rewriter.create <mlir::spirv::AddressOfOp>(loc, globalVar);
692
- mlir::Value fmtStr = rewriter.create <mlir::spirv::BitcastOp>(
704
+ Value globalPtr = rewriter.create <spirv::AddressOfOp>(loc, globalVar);
705
+ Value fmtStr = rewriter.create <spirv::BitcastOp>(
693
706
loc,
694
- mlir::spirv::PointerType::get (i8Type,
695
- mlir::spirv::StorageClass::UniformConstant),
707
+ spirv::PointerType::get (i8Type, spirv::StorageClass::UniformConstant),
696
708
globalPtr);
697
709
698
710
// Get printf arguments
699
711
auto argsRange = adaptor.getArgs ();
700
- mlir:: SmallVector<mlir:: Value, 4 > printfArgs;
712
+ SmallVector<Value, 4 > printfArgs;
701
713
printfArgs.reserve (argsRange.size () + 1 );
702
714
printfArgs.append (argsRange.begin (), argsRange.end ());
703
715
704
- rewriter.create <mlir:: spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
716
+ rewriter.create <spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
705
717
706
718
rewriter.eraseOp (gpuPrintfOp);
707
719
708
- return mlir:: success ();
720
+ return success ();
709
721
}
710
722
711
723
// ===----------------------------------------------------------------------===//
0 commit comments