@@ -550,6 +550,145 @@ static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
 
 static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 
+//===----------------------------------------------------------------------===//
+// InitTensorOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseInitTensorOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  OpAsmParser::OperandType srcInfo;
+  Type dstType;
+  SmallVector<OpAsmParser::OperandType, 2> sizeInfo;
+  IndexType indexType = parser.getBuilder().getIndexType();
+  if (failed(parseListOfOperandsOrIntegers(
+          parser, result, InitTensorOp::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizeInfo)) ||
+      failed(parser.parseOptionalAttrDict(result.attributes)) ||
+      failed(parser.parseColonType(dstType)) ||
+      failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
+    return failure();
+  return parser.addTypeToList(dstType, result.types);
+}
+
+static void print(OpAsmPrinter &p, InitTensorOp op) {
+  p << op.getOperation()->getName() << ' ';
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p.printOptionalAttrDict(op.getAttrs(),
+                          InitTensorOp::getStaticSizesAttrName());
+  p << " : " << op.getType();
+}
+
+static LogicalResult verify(InitTensorOp op) {
+  RankedTensorType resultType = op.getType();
+  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
+      op.static_sizes().cast<ArrayAttr>(),
+      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
+
+  if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
+                                            op.static_sizes(), op.sizes(),
+                                            ShapedType::isDynamic)))
+    return failure();
+
+  Type expectedType =
+      InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+  return success();
+}
+
+Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
+                                   Type elementType) {
+  return RankedTensorType::get(staticSizes, elementType);
+}
+
+namespace {
+/// Change the type of the result of a `linalg.init_tensor` by making the result
+/// type statically sized along dimensions that in the original operation were
+/// defined as dynamic, but whose size was defined using a `constant` op. For
+/// example:
+///
+///  %c5 = constant 5 : index
+///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
+///
+///  to
+///
+///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
+struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
+  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 4> dynamicSizes;
+    SmallVector<int64_t, 4> staticSizes;
+    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
+      // If the size is already static, nothing to do.
+      if (!op.isDynamicSize(i)) {
+        staticSizes.push_back(op.getStaticSize(i));
+        continue;
+      }
+
+      // If the size is dynamic but defined using a `constant` op, get the
+      // constant value to find the static size to use.
+      unsigned operandNum = op.getIndexOfDynamicSize(i);
+      Value sizeOperand = op.getOperand(operandNum);
+      if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
+        staticSizes.push_back(constantIndexOp.getValue());
+        continue;
+      }
+
+      // Fallback case. Keep the size dynamic.
+      dynamicSizes.push_back(sizeOperand);
+      staticSizes.push_back(ShapedType::kDynamicSize);
+    }
+    RankedTensorType newType =
+        RankedTensorType::get(staticSizes, op.getType().getElementType());
+    if (newType == op.getType())
+      return failure();
+    auto newOp =
+        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
+                                      rewriter.getI64ArrayAttr(staticSizes));
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
+/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
+/// with
+/// - A constant value if the size is static along the dimension.
+/// - The dynamic value that defines the size of the result of
+///   `linalg.init_tensor` op.
+struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
+    if (!initTensorOp)
+      return failure();
+    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
+    if (!dimIndex)
+      return failure();
+    int64_t index = dimIndex.getValue();
+    if (!initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+          dimOp, initTensorOp.getStaticSize(index));
+    } else {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+    }
+    return success();
+  }
+};
+} // namespace
+
+void InitTensorOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+}
+
 
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
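For reference, a small hand-written MLIR sketch (illustrative only; the value names %arg0 and %c1 and the sizes are placeholders, not taken from the commit or its tests) of the textual form handled by parseInitTensorOp/print and of what the two registered canonicalization patterns do with it:

    %c1 = constant 1 : index
    %0 = linalg.init_tensor [%arg0, 42] : tensor<?x42xf32>
    %1 = dim %0, %c1 : tensor<?x42xf32>

ReplaceDimOfInitTensorOp rewrites %1 to `constant 42 : index`, since dimension 1 of the result is static; a `dim` along dimension 0 would instead be replaced with %arg0, the operand that supplies the dynamic size. ReplaceStaticShapeDims fires only when a dynamic size operand is itself defined by a `constant` op, as in the example in its documentation comment above: it rebuilds the `linalg.init_tensor` with a more static result type and inserts a `tensor_cast` back to the original type so existing uses keep the type they expect.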