@@ -135,6 +135,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
135
135
ConversionPatternRewriter &rewriter) const override ;
136
136
};
137
137
138
+ class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
139
+ public:
140
+ using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
141
+
142
+ LogicalResult
143
+ matchAndRewrite (gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
144
+ ConversionPatternRewriter &rewriter) const override ;
145
+ };
146
+
138
147
} // namespace
139
148
140
149
// ===----------------------------------------------------------------------===//
@@ -607,6 +616,108 @@ class GPUSubgroupReduceConversion final
607
616
}
608
617
};
609
618
619
+ LogicalResult GPUPrintfConversion::matchAndRewrite (
620
+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
621
+ ConversionPatternRewriter &rewriter) const {
622
+
623
+ auto loc = gpuPrintfOp.getLoc ();
624
+
625
+ auto funcOp =
626
+ rewriter.getBlock ()->getParent ()->getParentOfType <mlir::spirv::FuncOp>();
627
+
628
+ auto moduleOp = funcOp->getParentOfType <mlir::spirv::ModuleOp>();
629
+
630
+ const char formatStringPrefix[] = " printfMsg" ;
631
+ unsigned stringNumber = 0 ;
632
+ mlir::SmallString<16 > globalVarName;
633
+ mlir::spirv::GlobalVariableOp globalVar;
634
+
635
+ // formulate spirv global variable name
636
+ do {
637
+ globalVarName.clear ();
638
+ (formatStringPrefix + llvm::Twine (stringNumber++))
639
+ .toStringRef (globalVarName);
640
+ } while (moduleOp.lookupSymbol (globalVarName));
641
+
642
+ auto i8Type = rewriter.getI8Type ();
643
+ auto i32Type = rewriter.getI32Type ();
644
+
645
+ unsigned scNum = 0 ;
646
+ auto createSpecConstant = [&](unsigned value) {
647
+ auto attr = rewriter.getI8IntegerAttr (value);
648
+ mlir::SmallString<16 > specCstName;
649
+ (llvm::Twine (globalVarName) + " _sc" + llvm::Twine (scNum++))
650
+ .toStringRef (specCstName);
651
+
652
+ return rewriter.create <mlir::spirv::SpecConstantOp>(
653
+ loc, rewriter.getStringAttr (specCstName), attr);
654
+ };
655
+
656
+ // define GlobalVarOp with printf format string using SpecConstants
657
+ // and make composite of SpecConstants
658
+ {
659
+ mlir::Operation *parent =
660
+ mlir::SymbolTable::getNearestSymbolTable (gpuPrintfOp->getParentOp ());
661
+
662
+ mlir::ConversionPatternRewriter::InsertionGuard guard (rewriter);
663
+
664
+ mlir::Block &entryBlock = *parent->getRegion (0 ).begin ();
665
+ rewriter.setInsertionPointToStart (
666
+ &entryBlock); // insertion point at module level
667
+
668
+ // Create Constituents with SpecConstant to construct
669
+ // SpecConstantCompositeOp
670
+ llvm::SmallString<20 > formatString (gpuPrintfOp.getFormat ());
671
+ formatString.push_back (' \0 ' ); // Null terminate for C
672
+ mlir::SmallVector<mlir::Attribute, 4 > constituents;
673
+ for (auto c : formatString) {
674
+ auto cSpecConstantOp = createSpecConstant (c);
675
+ constituents.push_back (mlir::SymbolRefAttr::get (cSpecConstantOp));
676
+ }
677
+
678
+ // Create specialization constant composite defined via spirv.SpecConstant
679
+ size_t contentSize = constituents.size ();
680
+ auto globalType = mlir::spirv::ArrayType::get (i8Type, contentSize);
681
+ mlir::spirv::SpecConstantCompositeOp specCstComposite;
682
+ mlir::SmallString<16 > specCstCompositeName;
683
+ (llvm::Twine (globalVarName) + " _scc" ).toStringRef (specCstCompositeName);
684
+ specCstComposite = rewriter.create <mlir::spirv::SpecConstantCompositeOp>(
685
+ loc, mlir::TypeAttr::get (globalType),
686
+ rewriter.getStringAttr (specCstCompositeName),
687
+ rewriter.getArrayAttr (constituents));
688
+
689
+ // Define GlobalVariable initialized from Constant Composite
690
+ globalVar = rewriter.create <mlir::spirv::GlobalVariableOp>(
691
+ loc,
692
+ mlir::spirv::PointerType::get (
693
+ globalType, mlir::spirv::StorageClass::UniformConstant),
694
+ globalVarName, mlir::FlatSymbolRefAttr::get (specCstComposite));
695
+
696
+ globalVar->setAttr (" Constant" , rewriter.getUnitAttr ());
697
+ }
698
+
699
+ // Get SSA value of Global variable
700
+ mlir::Value globalPtr =
701
+ rewriter.create <mlir::spirv::AddressOfOp>(loc, globalVar);
702
+ mlir::Value fmtStr = rewriter.create <mlir::spirv::BitcastOp>(
703
+ loc,
704
+ mlir::spirv::PointerType::get (i8Type,
705
+ mlir::spirv::StorageClass::UniformConstant),
706
+ globalPtr);
707
+
708
+ // Get printf arguments
709
+ auto argsRange = adaptor.getArgs ();
710
+ mlir::SmallVector<mlir::Value, 4 > printfArgs;
711
+ printfArgs.reserve (argsRange.size () + 1 );
712
+ printfArgs.append (argsRange.begin (), argsRange.end ());
713
+
714
+ rewriter.create <mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
715
+
716
+ rewriter.eraseOp (gpuPrintfOp);
717
+
718
+ return mlir::success ();
719
+ }
720
+
610
721
// ===----------------------------------------------------------------------===//
611
722
// GPU To SPIRV Patterns.
612
723
// ===----------------------------------------------------------------------===//
@@ -630,5 +741,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
630
741
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
631
742
spirv::BuiltIn::SubgroupSize>,
632
743
WorkGroupSizeConversion, GPUAllReduceConversion,
633
- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext ());
744
+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
745
+ patterns.getContext ());
634
746
}
0 commit comments