23
23
#include " mlir/IR/Value.h"
24
24
#include " mlir/Support/MathExtras.h"
25
25
#include " mlir/Transforms/InliningUtils.h"
26
+ #include " llvm/ADT/STLExtras.h"
26
27
#include " llvm/ADT/StringSwitch.h"
27
28
#include " llvm/Support/FormatVariadic.h"
28
29
#include " llvm/Support/raw_ostream.h"
@@ -2639,10 +2640,15 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2639
2640
// / `:` strided-memref-type `to` strided-memref-type
2640
2641
// / ```
2641
2642
template <typename OpType>
2642
- static void printOpWithOffsetsSizesAndStrides (OpAsmPrinter &p, OpType op) {
2643
+ static void printOpWithOffsetsSizesAndStrides (
2644
+ OpAsmPrinter &p, OpType op,
2645
+ llvm::function_ref<void (OpAsmPrinter &p, OpType op)> printExtraOperands =
2646
+ [](OpAsmPrinter &p, OpType op) {},
2647
+ StringLiteral resultTypeKeyword = " to" ) {
2643
2648
int stdDotLen = StandardOpsDialect::getDialectNamespace ().size () + 1 ;
2644
2649
p << op.getOperation ()->getName ().getStringRef ().drop_front (stdDotLen) << ' ' ;
2645
- p << op.getOperand (0 );
2650
+ p << op.source ();
2651
+ printExtraOperands (p, op);
2646
2652
printSubViewListOfOperandsOrIntegers (p, op.offsets (), op.static_offsets (),
2647
2653
ShapedType::isDynamicStrideOrOffset);
2648
2654
printSubViewListOfOperandsOrIntegers (p, op.sizes (), op.static_sizes (),
@@ -2651,27 +2657,35 @@ static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
2651
2657
ShapedType::isDynamicStrideOrOffset);
2652
2658
p.printOptionalAttrDict (op.getAttrs (),
2653
2659
/* elidedAttrs=*/ {OpType::getSpecialAttrNames ()});
2654
- p << " : " << op.getOperand (0 ).getType () << " to " << op.getType ();
2660
+ p << " : " << op.getSourceType () << " " << resultTypeKeyword << " "
2661
+ << op.getType ();
2655
2662
}
2656
2663
2657
2664
static void print (OpAsmPrinter &p, SubViewOp op) {
2658
2665
return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
2659
2666
}
2660
2667
2661
- // / Parse SubViewOp of the form:
2668
+ // / Parse of the form:
2662
2669
// / ```
2663
- // / `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2664
- // / `:` strided-memref-type `to` strided-memref-type
2670
+ // / `name` ssa-name (extra-operands)?
2671
+ // / `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2672
+ // / `:` strided-memref-type `resultTypeKeyword strided-memref-type
2665
2673
// / ```
2666
2674
template <typename OpType>
2667
- static ParseResult parseOpWithOffsetsSizesAndStrides (OpAsmParser &parser,
2668
- OperationState &result) {
2669
- OpAsmParser::OperandType srcInfo;
2675
+ static ParseResult parseOpWithOffsetsSizesAndStrides (
2676
+ OpAsmParser &parser, OperationState &result,
2677
+ std::function<ParseResult(OpAsmParser &p,
2678
+ OpAsmParser::OperandType &dstInfo)>
2679
+ parseExtraOperand = nullptr,
2680
+ StringLiteral resultTypeKeyword = "to") {
2681
+ OpAsmParser::OperandType srcInfo, dstInfo;
2670
2682
SmallVector<OpAsmParser::OperandType, 4 > offsetsInfo, sizesInfo, stridesInfo;
2671
2683
auto indexType = parser.getBuilder ().getIndexType ();
2672
2684
Type srcType, dstType;
2673
2685
if (parser.parseOperand (srcInfo))
2674
2686
return failure ();
2687
+ if (parseExtraOperand && parseExtraOperand (parser, dstInfo))
2688
+ return failure ();
2675
2689
if (parseListOfOperandsOrIntegers (
2676
2690
parser, result, OpType::getStaticOffsetsAttrName (),
2677
2691
ShapedType::kDynamicStrideOrOffset , offsetsInfo) ||
@@ -2683,21 +2697,27 @@ static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
2683
2697
ShapedType::kDynamicStrideOrOffset , stridesInfo))
2684
2698
return failure ();
2685
2699
2700
+ // Handle segment sizes.
2686
2701
auto b = parser.getBuilder ();
2687
- SmallVector<int , 4 > segmentSizes{1 , static_cast <int >(offsetsInfo.size ()),
2688
- static_cast <int >(sizesInfo.size ()),
2689
- static_cast <int >(stridesInfo.size ())};
2702
+ SmallVector<int , 4 > segmentSizes = {1 , static_cast <int >(offsetsInfo.size ()),
2703
+ static_cast <int >(sizesInfo.size ()),
2704
+ static_cast <int >(stridesInfo.size ())};
2705
+ // If we parse an extra operand it needs to appear in the segmentSizes
2706
+ if (parseExtraOperand)
2707
+ segmentSizes.insert (segmentSizes.begin (), 1 );
2690
2708
result.addAttribute (OpType::getOperandSegmentSizeAttr (),
2691
2709
b.getI32VectorAttr (segmentSizes));
2692
2710
2693
2711
return failure (
2694
2712
parser.parseOptionalAttrDict (result.attributes ) ||
2695
2713
parser.parseColonType (srcType) ||
2714
+ parser.parseKeywordType (resultTypeKeyword.str ().c_str (), dstType) ||
2696
2715
parser.resolveOperand (srcInfo, srcType, result.operands ) ||
2716
+ (parseExtraOperand &&
2717
+ parser.resolveOperand (dstInfo, dstType, result.operands )) ||
2697
2718
parser.resolveOperands (offsetsInfo, indexType, result.operands ) ||
2698
2719
parser.resolveOperands (sizesInfo, indexType, result.operands ) ||
2699
2720
parser.resolveOperands (stridesInfo, indexType, result.operands ) ||
2700
- parser.parseKeywordType (" to" , dstType) ||
2701
2721
parser.addTypeToList (dstType, result.types ));
2702
2722
}
2703
2723
@@ -2894,7 +2914,7 @@ static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
2894
2914
2895
2915
// / Verifier for SubViewOp.
2896
2916
static LogicalResult verify (SubViewOp op) {
2897
- MemRefType baseType = op.getSourceMemRefType ();
2917
+ MemRefType baseType = op.getSourceType ();
2898
2918
MemRefType subViewType = op.getType ();
2899
2919
2900
2920
// The base memref and the view memref should be in the same memory space.
@@ -3273,8 +3293,7 @@ static LogicalResult verify(SubTensorOp op) {
3273
3293
3274
3294
// Verify result type against inferred type.
3275
3295
auto expectedType = SubTensorOp::inferResultType (
3276
- op.getSourceRankedTensorType (),
3277
- extractFromI64ArrayAttr (op.static_offsets ()),
3296
+ op.getSourceType (), extractFromI64ArrayAttr (op.static_offsets ()),
3278
3297
extractFromI64ArrayAttr (op.static_sizes ()),
3279
3298
extractFromI64ArrayAttr (op.static_strides ()));
3280
3299
if (!isRankReducedType (expectedType, op.getType ()))
@@ -3291,6 +3310,72 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3291
3310
context);
3292
3311
}
3293
3312
3313
+ // ===----------------------------------------------------------------------===//
3314
+ // SubTensorInsertOp
3315
+ // ===----------------------------------------------------------------------===//
3316
+
3317
+ static void print (OpAsmPrinter &p, SubTensorInsertOp op) {
3318
+ return printOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
3319
+ p, op,
3320
+ [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest (); },
3321
+ /* resultTypeKeyword=*/ " into" );
3322
+ }
3323
+
3324
+ static ParseResult parseSubTensorInsertOp (OpAsmParser &parser,
3325
+ OperationState &result) {
3326
+ return parseOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
3327
+ parser, result,
3328
+ [](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) {
3329
+ return failure (parser.parseKeyword (" into" ) ||
3330
+ parser.parseOperand (dstInfo));
3331
+ },
3332
+ " into" );
3333
+ }
3334
+
3335
+ void mlir::SubTensorInsertOp::build (
3336
+ OpBuilder &b, OperationState &result, Value source, Value dest,
3337
+ ArrayRef<int64_t > staticOffsets, ArrayRef<int64_t > staticSizes,
3338
+ ArrayRef<int64_t > staticStrides, ValueRange offsets, ValueRange sizes,
3339
+ ValueRange strides, ArrayRef<NamedAttribute> attrs) {
3340
+ build (b, result, dest.getType (), source, dest, offsets, sizes, strides,
3341
+ b.getI64ArrayAttr (staticOffsets), b.getI64ArrayAttr (staticSizes),
3342
+ b.getI64ArrayAttr (staticStrides));
3343
+ result.addAttributes (attrs);
3344
+ }
3345
+
3346
+ // / Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
3347
+ // / and `staticStrides` are automatically filled with source-memref-rank
3348
+ // / sentinel values that encode dynamic entries.
3349
+ void mlir::SubTensorInsertOp::build (OpBuilder &b, OperationState &result,
3350
+ Value source, Value dest,
3351
+ ValueRange offsets, ValueRange sizes,
3352
+ ValueRange strides,
3353
+ ArrayRef<NamedAttribute> attrs) {
3354
+ auto sourceRankedTensorType = source.getType ().cast <RankedTensorType>();
3355
+ unsigned rank = sourceRankedTensorType.getRank ();
3356
+ SmallVector<int64_t , 4 > staticOffsetsVector (
3357
+ rank, ShapedType::kDynamicStrideOrOffset );
3358
+ SmallVector<int64_t , 4 > staticSizesVector (rank, ShapedType::kDynamicSize );
3359
+ SmallVector<int64_t , 4 > staticStridesVector (
3360
+ rank, ShapedType::kDynamicStrideOrOffset );
3361
+ build (b, result, source, dest, staticOffsetsVector, staticSizesVector,
3362
+ staticStridesVector, offsets, sizes, strides, attrs);
3363
+ }
3364
+
3365
+ SmallVector<Range, 8 > SubTensorInsertOp::getOrCreateRanges (OpBuilder &b,
3366
+ Location loc) {
3367
+ return ::getOrCreateRangesImpl (*this , b, loc);
3368
+ }
3369
+
3370
+ // / Verifier for SubViewOp.
3371
+ static LogicalResult verify (SubTensorInsertOp op) {
3372
+ if (failed (verifyOpWithOffsetSizesAndStrides (op)))
3373
+ return failure ();
3374
+ if (op.getType () != op.dest ().getType ())
3375
+ return op.emitError (" expected result type to be " ) << op.dest ().getType ();
3376
+ return success ();
3377
+ }
3378
+
3294
3379
// ===----------------------------------------------------------------------===//
3295
3380
// TensorCastOp
3296
3381
// ===----------------------------------------------------------------------===//
0 commit comments