@@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
121
121
ConversionPatternRewriter &rewriter) const override ;
122
122
};
123
123
124
+ class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
125
+ public:
126
+ using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
127
+
128
+ LogicalResult
129
+ matchAndRewrite (gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
130
+ ConversionPatternRewriter &rewriter) const override ;
131
+ };
132
+
124
133
} // namespace
125
134
126
135
// ===----------------------------------------------------------------------===//
@@ -597,6 +606,108 @@ class GPUSubgroupReduceConversion final
597
606
}
598
607
};
599
608
609
+ LogicalResult GPUPrintfConversion::matchAndRewrite (
610
+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
611
+ ConversionPatternRewriter &rewriter) const {
612
+
613
+ auto loc = gpuPrintfOp.getLoc ();
614
+
615
+ auto funcOp =
616
+ rewriter.getBlock ()->getParent ()->getParentOfType <mlir::spirv::FuncOp>();
617
+
618
+ auto moduleOp = funcOp->getParentOfType <mlir::spirv::ModuleOp>();
619
+
620
+ const char formatStringPrefix[] = " printfMsg" ;
621
+ unsigned stringNumber = 0 ;
622
+ mlir::SmallString<16 > globalVarName;
623
+ mlir::spirv::GlobalVariableOp globalVar;
624
+
625
+ // formulate spirv global variable name
626
+ do {
627
+ globalVarName.clear ();
628
+ (formatStringPrefix + llvm::Twine (stringNumber++))
629
+ .toStringRef (globalVarName);
630
+ } while (moduleOp.lookupSymbol (globalVarName));
631
+
632
+ auto i8Type = rewriter.getI8Type ();
633
+ auto i32Type = rewriter.getI32Type ();
634
+
635
+ unsigned scNum = 0 ;
636
+ auto createSpecConstant = [&](unsigned value) {
637
+ auto attr = rewriter.getI8IntegerAttr (value);
638
+ mlir::SmallString<16 > specCstName;
639
+ (llvm::Twine (globalVarName) + " _sc" + llvm::Twine (scNum++))
640
+ .toStringRef (specCstName);
641
+
642
+ return rewriter.create <mlir::spirv::SpecConstantOp>(
643
+ loc, rewriter.getStringAttr (specCstName), attr);
644
+ };
645
+
646
+ // define GlobalVarOp with printf format string using SpecConstants
647
+ // and make composite of SpecConstants
648
+ {
649
+ mlir::Operation *parent =
650
+ mlir::SymbolTable::getNearestSymbolTable (gpuPrintfOp->getParentOp ());
651
+
652
+ mlir::ConversionPatternRewriter::InsertionGuard guard (rewriter);
653
+
654
+ mlir::Block &entryBlock = *parent->getRegion (0 ).begin ();
655
+ rewriter.setInsertionPointToStart (
656
+ &entryBlock); // insertion point at module level
657
+
658
+ // Create Constituents with SpecConstant to construct
659
+ // SpecConstantCompositeOp
660
+ llvm::SmallString<20 > formatString (gpuPrintfOp.getFormat ());
661
+ formatString.push_back (' \0 ' ); // Null terminate for C
662
+ mlir::SmallVector<mlir::Attribute, 4 > constituents;
663
+ for (auto c : formatString) {
664
+ auto cSpecConstantOp = createSpecConstant (c);
665
+ constituents.push_back (mlir::SymbolRefAttr::get (cSpecConstantOp));
666
+ }
667
+
668
+ // Create specialization constant composite defined via spirv.SpecConstant
669
+ size_t contentSize = constituents.size ();
670
+ auto globalType = mlir::spirv::ArrayType::get (i8Type, contentSize);
671
+ mlir::spirv::SpecConstantCompositeOp specCstComposite;
672
+ mlir::SmallString<16 > specCstCompositeName;
673
+ (llvm::Twine (globalVarName) + " _scc" ).toStringRef (specCstCompositeName);
674
+ specCstComposite = rewriter.create <mlir::spirv::SpecConstantCompositeOp>(
675
+ loc, mlir::TypeAttr::get (globalType),
676
+ rewriter.getStringAttr (specCstCompositeName),
677
+ rewriter.getArrayAttr (constituents));
678
+
679
+ // Define GlobalVariable initialized from Constant Composite
680
+ globalVar = rewriter.create <mlir::spirv::GlobalVariableOp>(
681
+ loc,
682
+ mlir::spirv::PointerType::get (
683
+ globalType, mlir::spirv::StorageClass::UniformConstant),
684
+ globalVarName, mlir::FlatSymbolRefAttr::get (specCstComposite));
685
+
686
+ globalVar->setAttr (" Constant" , rewriter.getUnitAttr ());
687
+ }
688
+
689
+ // Get SSA value of Global variable
690
+ mlir::Value globalPtr =
691
+ rewriter.create <mlir::spirv::AddressOfOp>(loc, globalVar);
692
+ mlir::Value fmtStr = rewriter.create <mlir::spirv::BitcastOp>(
693
+ loc,
694
+ mlir::spirv::PointerType::get (i8Type,
695
+ mlir::spirv::StorageClass::UniformConstant),
696
+ globalPtr);
697
+
698
+ // Get printf arguments
699
+ auto argsRange = adaptor.getArgs ();
700
+ mlir::SmallVector<mlir::Value, 4 > printfArgs;
701
+ printfArgs.reserve (argsRange.size () + 1 );
702
+ printfArgs.append (argsRange.begin (), argsRange.end ());
703
+
704
+ rewriter.create <mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
705
+
706
+ rewriter.eraseOp (gpuPrintfOp);
707
+
708
+ return mlir::success ();
709
+ }
710
+
600
711
// ===----------------------------------------------------------------------===//
601
712
// GPU To SPIRV Patterns.
602
713
// ===----------------------------------------------------------------------===//
@@ -620,5 +731,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
620
731
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
621
732
spirv::BuiltIn::SubgroupSize>,
622
733
WorkGroupSizeConversion, GPUAllReduceConversion,
623
- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext ());
734
+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
735
+ patterns.getContext ());
624
736
}
0 commit comments