Skip to content

Commit b40a525

Browse files
author
git apple-llvm automerger
committed
Merge commit 'f38a24e6f891' from apple/main into swift/next
2 parents 202eb0c + f38a24e commit b40a525

File tree

3 files changed

+235
-35
lines changed

3 files changed

+235
-35
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 116 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,15 +2922,20 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
29222922

29232923
The SubView operation supports the following arguments:
29242924

2925-
* Memref: the "base" memref on which to create a "view" memref.
2926-
* Offsets: memref-rank number of dynamic offsets or static integer
2927-
attributes into the "base" memref at which to create the "view"
2928-
memref.
2929-
* Sizes: memref-rank number of dynamic sizes or static integer attributes
2930-
which specify the sizes of the result "view" memref type.
2931-
* Strides: memref-rank number of dynamic strides or static integer
2932-
attributes that compose multiplicatively with the base memref
2933-
strides in each dimension.
2925+
* memref: the "base" memref on which to create a "view" memref.
2926+
* offsets: memref-rank number of offsets into the "base" memref at which to
2927+
create the "view" memref.
2928+
* sizes: memref-rank number of sizes which specify the sizes of the result
2929+
"view" memref type.
2930+
* strides: memref-rank number of strides that compose multiplicatively with
2931+
the base memref strides in each dimension.
2932+
2933+
The representation based on offsets, sizes and strides supports a
2934+
partially-static specification via attributes specified through the
2935+
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
2936+
sentinel value ShapedType::kDynamicSize and
2937+
ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
2938+
a dynamic value.
29342939

29352940
A subview operation may additionally reduce the rank of the resulting view
29362941
by removing dimensions that are statically known to be of size 1.
@@ -3076,7 +3081,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
30763081

30773082
let extraClassDeclaration = extraBaseClassDeclaration # [{
30783083
/// Returns the type of the base memref operand.
3079-
MemRefType getSourceMemRefType() {
3084+
MemRefType getSourceType() {
30803085
return source().getType().cast<MemRefType>();
30813086
}
30823087

@@ -3108,13 +3113,19 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
31083113
The subtensor operation supports the following arguments:
31093114

31103115
* tensor: the "base" tensor from which to extract a subtensor.
3111-
* offsets: tensor-rank number of dynamic offsets or static integer
3112-
attributes into the "base" tensor from which to extract the
3113-
subtensor.
3114-
* sizes: tensor-rank number of dynamic sizes or static integer attributes
3115-
which specify the sizes of the result tensor type.
3116-
* strides: tensor-rank number of dynamic strides or static integer
3117-
attributes specifying susampling in each dimension.
3116+
* offsets: tensor-rank number of offsets into the "base" tensor from which
3117+
to extract the subtensor.
3118+
* sizes: tensor-rank number of sizes which specify the sizes of the result
3119+
tensor type.
3120+
* strides: tensor-rank number of strides specifying subsampling in each
3121+
dimension.
3122+
3123+
The representation based on offsets, sizes and strides supports a
3124+
partially-static specification via attributes specified through the
3125+
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
3126+
sentinel value ShapedType::kDynamicSize and
3127+
ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
3128+
a dynamic value.
31183129

31193130
After buffer-allocation, the "subtensor" op is expected to lower into a
31203131
"subview" op.
@@ -3144,9 +3155,22 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
31443155
);
31453156
let results = (outs AnyRankedTensor:$result);
31463157

3158+
let builders = [
3159+
// Build a SubTensorOp with mixed static and dynamic entries.
3160+
OpBuilder<
3161+
"Value source, ArrayRef<int64_t> staticOffsets, "
3162+
"ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
3163+
"ValueRange offsets, ValueRange sizes, ValueRange strides, "
3164+
"ArrayRef<NamedAttribute> attrs = {}">,
3165+
// Build a SubTensorOp with all dynamic entries.
3166+
OpBuilder<
3167+
"Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
3168+
"ArrayRef<NamedAttribute> attrs = {}">
3169+
];
3170+
31473171
let extraClassDeclaration = extraBaseClassDeclaration # [{
31483172
/// Returns the type of the base tensor operand.
3149-
RankedTensorType getSourceRankedTensorType() {
3173+
RankedTensorType getSourceType() {
31503174
return source().getType().cast<RankedTensorType>();
31513175
}
31523176

@@ -3167,6 +3191,80 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
31673191
let hasCanonicalizer = 1;
31683192
}
31693193

3194+
//===----------------------------------------------------------------------===//
3195+
// SubTensorInsertOp
3196+
//===----------------------------------------------------------------------===//
3197+
3198+
def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> {
3199+
let summary = "subtensor_insert operation";
3200+
let description = [{
3201+
The "subtensor_insert" operation inserts a tensor `source` into another
3202+
tensor `dest` as specified by the operation's offsets, sizes and strides
3203+
arguments.
3204+
3205+
It returns a copy of `dest` with the proper subtensor updated with the value
3206+
of `source`.
3207+
3208+
The subtensor_insert operation encodes the following information:
3209+
3210+
* source: the tensor that is inserted.
3211+
* dest: the tensor into which the source tensor is inserted.
3212+
* offsets: tensor-rank number of offsets into the "base" tensor from which
3213+
to extract the subtensor.
3214+
* sizes: tensor-rank number of sizes which specify the sizes of the result
3215+
tensor type.
3216+
* strides: tensor-rank number of strides that specify subsampling in each
3217+
dimension.
3218+
3219+
The representation based on offsets, sizes and strides supports a
3220+
partially-static specification via attributes specified through the
3221+
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
3222+
sentinel value ShapedType::kDynamicSize and
3223+
ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
3224+
a dynamic value.
3225+
3226+
After buffer-allocation, the "subtensor_insert" op is expected to become
3227+
an in-place buffer update.
3228+
}];
3229+
3230+
let arguments = (ins
3231+
AnyRankedTensor:$source,
3232+
AnyRankedTensor:$dest,
3233+
Variadic<Index>:$offsets,
3234+
Variadic<Index>:$sizes,
3235+
Variadic<Index>:$strides,
3236+
I64ArrayAttr:$static_offsets,
3237+
I64ArrayAttr:$static_sizes,
3238+
I64ArrayAttr:$static_strides
3239+
);
3240+
let results = (outs AnyRankedTensor:$result);
3241+
3242+
let builders = [
3243+
// Build a SubTensorInsertOp with mixed static and dynamic entries.
3244+
OpBuilder<
3245+
"Value source, Value dest, ArrayRef<int64_t> staticOffsets, "
3246+
"ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
3247+
"ValueRange offsets, ValueRange sizes, ValueRange strides, "
3248+
"ArrayRef<NamedAttribute> attrs = {}">,
3249+
// Build a SubTensorInsertOp with all dynamic entries.
3250+
OpBuilder<
3251+
"Value source, Value dest, ValueRange offsets, ValueRange sizes, "
3252+
"ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">
3253+
];
3254+
3255+
let extraClassDeclaration = extraBaseClassDeclaration # [{
3256+
/// Returns the type of the base tensor operand.
3257+
RankedTensorType getSourceType() {
3258+
return source().getType().cast<RankedTensorType>();
3259+
}
3260+
3261+
/// The result of a subtensor is always a tensor.
3262+
RankedTensorType getType() {
3263+
return getResult().getType().cast<RankedTensorType>();
3264+
}
3265+
}];
3266+
}
3267+
31703268
//===----------------------------------------------------------------------===//
31713269
// TanhOp
31723270
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 101 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/Value.h"
2424
#include "mlir/Support/MathExtras.h"
2525
#include "mlir/Transforms/InliningUtils.h"
26+
#include "llvm/ADT/STLExtras.h"
2627
#include "llvm/ADT/StringSwitch.h"
2728
#include "llvm/Support/FormatVariadic.h"
2829
#include "llvm/Support/raw_ostream.h"
@@ -2639,10 +2640,15 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
26392640
/// `:` strided-memref-type `to` strided-memref-type
26402641
/// ```
26412642
template <typename OpType>
2642-
static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
2643+
static void printOpWithOffsetsSizesAndStrides(
2644+
OpAsmPrinter &p, OpType op,
2645+
llvm::function_ref<void(OpAsmPrinter &p, OpType op)> printExtraOperands =
2646+
[](OpAsmPrinter &p, OpType op) {},
2647+
StringLiteral resultTypeKeyword = "to") {
26432648
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
26442649
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
2645-
p << op.getOperand(0);
2650+
p << op.source();
2651+
printExtraOperands(p, op);
26462652
printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
26472653
ShapedType::isDynamicStrideOrOffset);
26482654
printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
@@ -2651,27 +2657,35 @@ static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
26512657
ShapedType::isDynamicStrideOrOffset);
26522658
p.printOptionalAttrDict(op.getAttrs(),
26532659
/*elidedAttrs=*/{OpType::getSpecialAttrNames()});
2654-
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2660+
p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
2661+
<< op.getType();
26552662
}
26562663

26572664
static void print(OpAsmPrinter &p, SubViewOp op) {
26582665
return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
26592666
}
26602667

2661-
/// Parse SubViewOp of the form:
2668+
/// Parse of the form:
26622669
/// ```
2663-
/// `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2664-
/// `:` strided-memref-type `to` strided-memref-type
2670+
/// `name` ssa-name (extra-operands)?
2671+
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2672+
/// `:` strided-memref-type `resultTypeKeyword` strided-memref-type
26652673
/// ```
26662674
template <typename OpType>
2667-
static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
2668-
OperationState &result) {
2669-
OpAsmParser::OperandType srcInfo;
2675+
static ParseResult parseOpWithOffsetsSizesAndStrides(
2676+
OpAsmParser &parser, OperationState &result,
2677+
std::function<ParseResult(OpAsmParser &p,
2678+
OpAsmParser::OperandType &dstInfo)>
2679+
parseExtraOperand = nullptr,
2680+
StringLiteral resultTypeKeyword = "to") {
2681+
OpAsmParser::OperandType srcInfo, dstInfo;
26702682
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
26712683
auto indexType = parser.getBuilder().getIndexType();
26722684
Type srcType, dstType;
26732685
if (parser.parseOperand(srcInfo))
26742686
return failure();
2687+
if (parseExtraOperand && parseExtraOperand(parser, dstInfo))
2688+
return failure();
26752689
if (parseListOfOperandsOrIntegers(
26762690
parser, result, OpType::getStaticOffsetsAttrName(),
26772691
ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
@@ -2683,21 +2697,27 @@ static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
26832697
ShapedType::kDynamicStrideOrOffset, stridesInfo))
26842698
return failure();
26852699

2700+
// Handle segment sizes.
26862701
auto b = parser.getBuilder();
2687-
SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
2688-
static_cast<int>(sizesInfo.size()),
2689-
static_cast<int>(stridesInfo.size())};
2702+
SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offsetsInfo.size()),
2703+
static_cast<int>(sizesInfo.size()),
2704+
static_cast<int>(stridesInfo.size())};
2705+
// If we parse an extra operand it needs to appear in the segmentSizes
2706+
if (parseExtraOperand)
2707+
segmentSizes.insert(segmentSizes.begin(), 1);
26902708
result.addAttribute(OpType::getOperandSegmentSizeAttr(),
26912709
b.getI32VectorAttr(segmentSizes));
26922710

26932711
return failure(
26942712
parser.parseOptionalAttrDict(result.attributes) ||
26952713
parser.parseColonType(srcType) ||
2714+
parser.parseKeywordType(resultTypeKeyword.str().c_str(), dstType) ||
26962715
parser.resolveOperand(srcInfo, srcType, result.operands) ||
2716+
(parseExtraOperand &&
2717+
parser.resolveOperand(dstInfo, dstType, result.operands)) ||
26972718
parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
26982719
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
26992720
parser.resolveOperands(stridesInfo, indexType, result.operands) ||
2700-
parser.parseKeywordType("to", dstType) ||
27012721
parser.addTypeToList(dstType, result.types));
27022722
}
27032723

@@ -2894,7 +2914,7 @@ static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
28942914

28952915
/// Verifier for SubViewOp.
28962916
static LogicalResult verify(SubViewOp op) {
2897-
MemRefType baseType = op.getSourceMemRefType();
2917+
MemRefType baseType = op.getSourceType();
28982918
MemRefType subViewType = op.getType();
28992919

29002920
// The base memref and the view memref should be in the same memory space.
@@ -3273,8 +3293,7 @@ static LogicalResult verify(SubTensorOp op) {
32733293

32743294
// Verify result type against inferred type.
32753295
auto expectedType = SubTensorOp::inferResultType(
3276-
op.getSourceRankedTensorType(),
3277-
extractFromI64ArrayAttr(op.static_offsets()),
3296+
op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
32783297
extractFromI64ArrayAttr(op.static_sizes()),
32793298
extractFromI64ArrayAttr(op.static_strides()));
32803299
if (!isRankReducedType(expectedType, op.getType()))
@@ -3291,6 +3310,72 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
32913310
context);
32923311
}
32933312

3313+
//===----------------------------------------------------------------------===//
3314+
// SubTensorInsertOp
3315+
//===----------------------------------------------------------------------===//
3316+
3317+
static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
3318+
return printOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
3319+
p, op,
3320+
[](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); },
3321+
/*resultTypeKeyword=*/"into");
3322+
}
3323+
3324+
static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
3325+
OperationState &result) {
3326+
return parseOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
3327+
parser, result,
3328+
[](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) {
3329+
return failure(parser.parseKeyword("into") ||
3330+
parser.parseOperand(dstInfo));
3331+
},
3332+
"into");
3333+
}
3334+
3335+
void mlir::SubTensorInsertOp::build(
3336+
OpBuilder &b, OperationState &result, Value source, Value dest,
3337+
ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
3338+
ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes,
3339+
ValueRange strides, ArrayRef<NamedAttribute> attrs) {
3340+
build(b, result, dest.getType(), source, dest, offsets, sizes, strides,
3341+
b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
3342+
b.getI64ArrayAttr(staticStrides));
3343+
result.addAttributes(attrs);
3344+
}
3345+
3346+
/// Build a SubTensorInsertOp with all dynamic entries: `staticOffsets`, `staticSizes`
3347+
/// and `staticStrides` are automatically filled with source-memref-rank
3348+
/// sentinel values that encode dynamic entries.
3349+
void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
3350+
Value source, Value dest,
3351+
ValueRange offsets, ValueRange sizes,
3352+
ValueRange strides,
3353+
ArrayRef<NamedAttribute> attrs) {
3354+
auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
3355+
unsigned rank = sourceRankedTensorType.getRank();
3356+
SmallVector<int64_t, 4> staticOffsetsVector(
3357+
rank, ShapedType::kDynamicStrideOrOffset);
3358+
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
3359+
SmallVector<int64_t, 4> staticStridesVector(
3360+
rank, ShapedType::kDynamicStrideOrOffset);
3361+
build(b, result, source, dest, staticOffsetsVector, staticSizesVector,
3362+
staticStridesVector, offsets, sizes, strides, attrs);
3363+
}
3364+
3365+
SmallVector<Range, 8> SubTensorInsertOp::getOrCreateRanges(OpBuilder &b,
3366+
Location loc) {
3367+
return ::getOrCreateRangesImpl(*this, b, loc);
3368+
}
3369+
3370+
/// Verifier for SubTensorInsertOp.
3371+
static LogicalResult verify(SubTensorInsertOp op) {
3372+
if (failed(verifyOpWithOffsetSizesAndStrides(op)))
3373+
return failure();
3374+
if (op.getType() != op.dest().getType())
3375+
return op.emitError("expected result type to be ") << op.dest().getType();
3376+
return success();
3377+
}
3378+
32943379
//===----------------------------------------------------------------------===//
32953380
// TensorCastOp
32963381
//===----------------------------------------------------------------------===//

mlir/test/IR/core-ops.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,6 @@ func @assume_alignment(%0: memref<4x4xf16>) {
901901
return
902902
}
903903

904-
905904
// CHECK-LABEL: func @subtensor({{.*}}) {
906905
func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
907906
%c0 = constant 0 : index
@@ -924,3 +923,21 @@ func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
924923

925924
return
926925
}
926+
927+
// CHECK-LABEL: func @subtensor_insert({{.*}}) {
928+
func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx : index) {
929+
%c0 = constant 0 : index
930+
%c1 = constant 1 : index
931+
932+
// CHECK: subtensor_insert
933+
// CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
934+
%1 = subtensor_insert %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
935+
: tensor<8x16x4xf32> into tensor<16x32x8xf32>
936+
937+
// CHECK: subtensor_insert
938+
// CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
939+
%2 = subtensor_insert %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
940+
: tensor<8x16x4xf32> into tensor<16x32x8xf32>
941+
942+
return
943+
}

0 commit comments

Comments
 (0)