Skip to content

Commit 118a715

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Define a linalg.init_tensor operation.
This operation is used to materialize a tensor of a particular shape. The shape could be specified as a mix of static and dynamic values. The use of this operation is to be an `init` tensor for Linalg structured operation on tensors where the bounds of the computation depends on the shape of the output of the linalg operation. The result of this operation will be used as the `init` tensor of such Linalg operations. To note, 1) The values in the tensor materialized is not used. Any operation to which this is an init tensor is expected to overwrite the entire tensor. 2) The tensor is materialized only for the shape of the output and to make the loop bounds depend only on operands of the structured operation. Based on (1) and (2) it is assumed that these operations eventually go away since they are only used in `dim` operations that can be canonicalized to make this operation dead. Such canonicalization are added here too. Differential Revision: https://reviews.llvm.org/D93374
1 parent de03121 commit 118a715

File tree

6 files changed

+350
-42
lines changed

6 files changed

+350
-42
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,91 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
3232
let parser = [{ return ::parse$cppClass(parser, result); }];
3333
}
3434

35+
def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
36+
let summary = "operation to define a tensor of particular value";
37+
38+
let description = [{
39+
`linalg.init_tensor` is an operation that materializes a tensor of
40+
a given shape. The shape could be dynamic or static.
41+
}];
42+
43+
let arguments =
44+
(ins Variadic<Index>:$sizes, I64ArrayAttr:$static_sizes);
45+
46+
let results = (outs AnyTensor:$result);
47+
48+
let verifier = [{ return ::verify(*this); }];
49+
50+
let extraClassDeclaration = [{
51+
static StringRef getStaticSizesAttrName() {
52+
return "static_sizes";
53+
}
54+
55+
RankedTensorType getType() {
56+
return getResult().getType().cast<RankedTensorType>(); }
57+
58+
// Infer the shape of the result tensor given the static shapes
59+
// and element type of the result tensor.
60+
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
61+
62+
// Return true if the size of the tensor is dynamic at `idx`
63+
bool isDynamicSize(unsigned idx) {
64+
APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
65+
return ShapedType::isDynamic(v.getSExtValue());
66+
}
67+
68+
// Assert that the size of the result tensor is static at `idx`
69+
// and return the shape.
70+
int64_t getStaticSize(unsigned idx) {
71+
assert(!isDynamicSize(idx) && "expected static size");
72+
APInt v = *(static_sizes().
73+
template getAsValueRange<IntegerAttr>().begin() + idx);
74+
return v.getSExtValue();
75+
}
76+
77+
// Return the argument position that contains the dynamic size of
78+
// the tensor at dimension `idx`. Asserts that the shape is
79+
// dynamic at that `idx`.
80+
unsigned getIndexOfDynamicSize(unsigned idx) {
81+
assert(isDynamicSize(idx) && "expected dynamic size");
82+
return std::count_if(
83+
static_sizes().getValue().begin(),
84+
static_sizes().getValue().begin() + idx,
85+
[&](Attribute attr) {
86+
return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
87+
});
88+
}
89+
90+
// Return the Value of the dynamic size of the tensor at dimension
91+
// `idx`. Asserts that the shape is dynamic at that `idx.
92+
Value getDynamicSize(unsigned idx) {
93+
return getOperand(getIndexOfDynamicSize(idx));
94+
}
95+
}];
96+
97+
let builders = [
98+
OpBuilderDAG<(ins "ValueRange":$shape,
99+
"ArrayRef<int64_t>":$staticShape, "Type":$elementType),
100+
[{
101+
build($_builder, $_state,
102+
InitTensorOp::inferResultType(staticShape, elementType),
103+
shape, $_builder.getI64ArrayAttr(staticShape));
104+
}]>,
105+
OpBuilderDAG<(ins "ValueRange":$shape, "Type":$elementType),
106+
[{
107+
SmallVector<int64_t, 4> staticShape(
108+
shape.size(), ShapedType::kDynamicSize);
109+
build($_builder, $_state, shape, staticShape, elementType);
110+
}]>,
111+
OpBuilderDAG<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
112+
[{
113+
build($_builder, $_state, ValueRange{}, staticShape, elementType);
114+
}]>
115+
];
116+
117+
let hasCanonicalizer = 1;
118+
}
119+
35120
def Linalg_RangeOp :
36121
Linalg_Op<"range", [NoSideEffect]>,
37122
Arguments<(ins Index:$min, Index:$max, Index:$step)>,

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ LogicalResult verify(OffsetSizeAndStrideOpInterface op);
3535
#include "mlir/Interfaces/ViewLikeInterface.h.inc"
3636

3737
namespace mlir {
38+
/// Print a list with either (1) the static integer value in `arrayAttr` if
39+
/// `isDynamic` evaluates to false or (2) the next value otherwise.
40+
/// This allows idiomatic printing of mixed value and integer attributes in a
41+
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
42+
void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
43+
ArrayAttr arrayAttr,
44+
llvm::function_ref<bool(int64_t)> isDynamic);
45+
3846
/// Print part of an op of the form:
3947
/// ```
4048
/// <optional-offset-prefix>`[` offset-list `]`
@@ -48,6 +56,19 @@ void printOffsetsSizesAndStrides(
4856
ArrayRef<StringRef> elidedAttrs =
4957
OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
5058

59+
/// Parse a mixed list with either (1) static integer values or (2) SSA values.
60+
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
61+
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
62+
/// in-order.
63+
//
64+
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
65+
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
66+
/// 2. `ssa` is filled with "[%arg0, %arg1]".
67+
ParseResult
68+
parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
69+
StringRef attrName, int64_t dynVal,
70+
SmallVectorImpl<OpAsmParser::OperandType> &ssa);
71+
5172
/// Parse trailing part of an op of the form:
5273
/// ```
5374
/// <optional-offset-prefix>`[` offset-list `]`
@@ -87,6 +108,12 @@ ParseResult parseOffsetsSizesAndStrides(
87108
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
88109
nullptr);
89110

111+
/// Verify that a the `values` has as many elements as the number of entries in
112+
/// `attr` for which `isDynamic` evaluates to true.
113+
LogicalResult verifyListOfOperandsOrIntegers(
114+
Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
115+
ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
116+
90117
} // namespace mlir
91118

92119
#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,145 @@ static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
550550

551551
static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
552552

553+
//===----------------------------------------------------------------------===//
554+
// InitTensorOp
555+
//===----------------------------------------------------------------------===//
556+
557+
static ParseResult parseInitTensorOp(OpAsmParser &parser,
558+
OperationState &result) {
559+
OpAsmParser::OperandType srcInfo;
560+
Type dstType;
561+
SmallVector<OpAsmParser::OperandType, 2> sizeInfo;
562+
IndexType indexType = parser.getBuilder().getIndexType();
563+
if (failed(parseListOfOperandsOrIntegers(
564+
parser, result, InitTensorOp::getStaticSizesAttrName(),
565+
ShapedType::kDynamicSize, sizeInfo)) ||
566+
failed(parser.parseOptionalAttrDict(result.attributes)) ||
567+
failed(parser.parseColonType(dstType)) ||
568+
failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
569+
return failure();
570+
return parser.addTypeToList(dstType, result.types);
571+
}
572+
573+
static void print(OpAsmPrinter &p, InitTensorOp op) {
574+
p << op.getOperation()->getName() << ' ';
575+
printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
576+
ShapedType::isDynamic);
577+
p.printOptionalAttrDict(op.getAttrs(),
578+
InitTensorOp::getStaticSizesAttrName());
579+
p << " : " << op.getType();
580+
}
581+
582+
static LogicalResult verify(InitTensorOp op) {
583+
RankedTensorType resultType = op.getType();
584+
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
585+
op.static_sizes().cast<ArrayAttr>(),
586+
[](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
587+
588+
if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
589+
op.static_sizes(), op.sizes(),
590+
ShapedType::isDynamic)))
591+
return failure();
592+
593+
Type expectedType =
594+
InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
595+
if (resultType != expectedType) {
596+
return op.emitError("specified type ")
597+
<< resultType << " does not match the inferred type "
598+
<< expectedType;
599+
}
600+
return success();
601+
}
602+
603+
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
604+
Type elementType) {
605+
return RankedTensorType::get(staticSizes, elementType);
606+
}
607+
608+
namespace {
609+
/// Change the type of the result of a `linalg.init_tensor` by making the result
610+
/// type statically sized along dimension that in the original operation where
611+
/// defined as dynamic, but the size was defined using a `constant` op. For
612+
/// example
613+
///
614+
/// %c5 = constant 5: index
615+
/// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
616+
///
617+
/// to
618+
///
619+
/// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
620+
struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
621+
using OpRewritePattern<InitTensorOp>::OpRewritePattern;
622+
623+
LogicalResult matchAndRewrite(InitTensorOp op,
624+
PatternRewriter &rewriter) const override {
625+
SmallVector<Value, 4> dynamicSizes;
626+
SmallVector<int64_t, 4> staticSizes;
627+
for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
628+
// If the size is already static, nothing to do.
629+
if (!op.isDynamicSize(i)) {
630+
staticSizes.push_back(op.getStaticSize(i));
631+
continue;
632+
}
633+
634+
// If the size is dynamic but defined using a `constant` op, get the
635+
// constant value to find the static size to use.
636+
unsigned operandNum = op.getIndexOfDynamicSize(i);
637+
Value sizeOperand = op.getOperand(operandNum);
638+
if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
639+
staticSizes.push_back(constantIndexOp.getValue());
640+
continue;
641+
}
642+
643+
// Fallback case. Keep the size dynamic.
644+
dynamicSizes.push_back(sizeOperand);
645+
staticSizes.push_back(ShapedType::kDynamicSize);
646+
}
647+
RankedTensorType newType =
648+
RankedTensorType::get(staticSizes, op.getType().getElementType());
649+
if (newType == op.getType())
650+
return failure();
651+
auto newOp =
652+
rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
653+
rewriter.getI64ArrayAttr(staticSizes));
654+
rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
655+
return success();
656+
}
657+
};
658+
659+
/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
660+
/// with
661+
/// - A constant value if the size is static along the dimension.
662+
/// - The dynamic value that defines the size of the result of
663+
/// `linalg.init_tensor` op.
664+
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
665+
using OpRewritePattern<DimOp>::OpRewritePattern;
666+
667+
LogicalResult matchAndRewrite(DimOp dimOp,
668+
PatternRewriter &rewriter) const override {
669+
auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
670+
if (!initTensorOp)
671+
return failure();
672+
auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
673+
if (!dimIndex)
674+
return failure();
675+
int64_t index = dimIndex.getValue();
676+
if (!initTensorOp.isDynamicSize(index)) {
677+
rewriter.replaceOpWithNewOp<ConstantIndexOp>(
678+
dimOp, initTensorOp.getStaticSize(index));
679+
} else {
680+
rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
681+
}
682+
return success();
683+
}
684+
};
685+
} // namespace
686+
687+
void InitTensorOp::getCanonicalizationPatterns(
688+
OwningRewritePatternList &results, MLIRContext *context) {
689+
results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
690+
}
691+
553692
//===----------------------------------------------------------------------===//
554693
// ReshapeOp
555694
//===----------------------------------------------------------------------===//

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,54 +17,43 @@ using namespace mlir;
1717
/// Include the definitions of the loop-like interfaces.
1818
#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
1919

20-
static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
21-
OffsetSizeAndStrideOpInterface op, StringRef name,
22-
unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
23-
llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
20+
LogicalResult mlir::verifyListOfOperandsOrIntegers(
21+
Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
22+
ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
2423
/// Check static and dynamic offsets/sizes/strides breakdown.
2524
if (attr.size() != expectedNumElements)
26-
return op.emitError("expected ")
25+
return op->emitError("expected ")
2726
<< expectedNumElements << " " << name << " values";
2827
unsigned expectedNumDynamicEntries =
2928
llvm::count_if(attr.getValue(), [&](Attribute attr) {
3029
return isDynamic(attr.cast<IntegerAttr>().getInt());
3130
});
3231
if (values.size() != expectedNumDynamicEntries)
33-
return op.emitError("expected ")
32+
return op->emitError("expected ")
3433
<< expectedNumDynamicEntries << " dynamic " << name << " values";
3534
return success();
3635
}
3736

3837
LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
3938
std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
40-
if (failed(verifyOpWithOffsetSizesAndStridesPart(
41-
op, "offset", ranks[0],
42-
OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
43-
op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
44-
op.offsets())))
39+
if (failed(verifyListOfOperandsOrIntegers(
40+
op, "offset", ranks[0], op.static_offsets(), op.offsets(),
41+
ShapedType::isDynamicStrideOrOffset)))
4542
return failure();
46-
if (failed(verifyOpWithOffsetSizesAndStridesPart(
47-
op, "size", ranks[1],
48-
OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
49-
op.static_sizes(), ShapedType::isDynamic, op.sizes())))
43+
if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1],
44+
op.static_sizes(), op.sizes(),
45+
ShapedType::isDynamic)))
5046
return failure();
51-
if (failed(verifyOpWithOffsetSizesAndStridesPart(
52-
op, "stride", ranks[2],
53-
OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
54-
op.static_strides(), ShapedType::isDynamicStrideOrOffset,
55-
op.strides())))
47+
if (failed(verifyListOfOperandsOrIntegers(
48+
op, "stride", ranks[2], op.static_strides(), op.strides(),
49+
ShapedType::isDynamicStrideOrOffset)))
5650
return failure();
5751
return success();
5852
}
5953

60-
/// Print a list with either (1) the static integer value in `arrayAttr` if
61-
/// `isDynamic` evaluates to false or (2) the next value otherwise.
62-
/// This allows idiomatic printing of mixed value and integer attributes in a
63-
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
64-
static void
65-
printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
66-
ArrayAttr arrayAttr,
67-
llvm::function_ref<bool(int64_t)> isDynamic) {
54+
void mlir::printListOfOperandsOrIntegers(
55+
OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
56+
llvm::function_ref<bool(int64_t)> isDynamic) {
6857
p << '[';
6958
unsigned idx = 0;
7059
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
@@ -95,18 +84,9 @@ void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
9584
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
9685
}
9786

98-
/// Parse a mixed list with either (1) static integer values or (2) SSA values.
99-
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
100-
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
101-
/// in-order.
102-
//
103-
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
104-
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
105-
/// 2. `ssa` is filled with "[%arg0, %arg1]".
106-
static ParseResult
107-
parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
108-
StringRef attrName, int64_t dynVal,
109-
SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
87+
ParseResult mlir::parseListOfOperandsOrIntegers(
88+
OpAsmParser &parser, OperationState &result, StringRef attrName,
89+
int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
11090
if (failed(parser.parseLSquare()))
11191
return failure();
11292
// 0-D.

0 commit comments

Comments
 (0)