Skip to content

Commit cfc9dda

Browse files
[mlir][interfaces][NFC] Move DestinationStyleOpInterface to mlir/Interfaces
This is the second (and final) step of making "destination style" usable without depending on the Linalg dialect. (The first step was D135129.) This change allows us to provide default bufferization implementations for all destination-style ops. It also allows us to simplify `TilingInterface`. (E.g., `getDestinationOperands` can be removed.) Differential Revision: https://reviews.llvm.org/D136179
1 parent 44027f3 commit cfc9dda

File tree

17 files changed

+467
-357
lines changed

17 files changed

+467
-357
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/TypeUtilities.h"
2121
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2222
#include "mlir/Interfaces/CopyOpInterface.h"
23+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2324
#include "mlir/Interfaces/InferTypeOpInterface.h"
2425
#include "mlir/Interfaces/SideEffectInterfaces.h"
2526
#include "mlir/Interfaces/TilingInterface.h"

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,14 @@
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/ImplicitLocOpBuilder.h"
2121
#include "mlir/IR/OpDefinition.h"
22+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2223
#include "mlir/Interfaces/InferTypeOpInterface.h"
2324
#include "mlir/Interfaces/ViewLikeInterface.h"
2425

2526
namespace mlir {
2627
namespace linalg {
2728
class LinalgOp;
2829

29-
/// OpOperand vector that implicitly converts to a Value vector.
30-
struct OpOperandVector : public SmallVector<OpOperand *> {
31-
operator SmallVector<Value>();
32-
};
33-
3430
namespace detail {
3531
/// Implementation of the method that that check if given operands
3632
/// can be dropped, i.e. the remaining operands can compute the loop
@@ -57,9 +53,6 @@ LogicalResult verifyFillInterface(Operation *op);
5753
/// Verify that `op` conforms to the invariants of StructuredOpInterface
5854
LogicalResult verifyStructuredOpInterface(Operation *op);
5955

60-
/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
61-
LogicalResult verifyDestinationStyleOpInterface(Operation *op);
62-
6356
} // namespace detail
6457
} // namespace linalg
6558
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 0 additions & 287 deletions
Original file line numberDiff line numberDiff line change
@@ -879,291 +879,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
879879
let verifyWithRegions = 1;
880880
}
881881

882-
// Ops that are in destination style have designated output operands, which act
883-
// as initial tensor values for the results of the operation or the output
884-
// buffers to which the results of the op will be written.
885-
//
886-
// Output operands must be tensors or memrefs. Input operands can have any
887-
// type. All non-output operands are inputs.
888-
889-
// It is assumed that the output operands of the op are the operands at
890-
// position [start, end). The positions are defined by getOutputsPositionRange
891-
// method. All non-output operands are "inputs" of the DPS op.
892-
893-
// If the op has "tensor semantics", then the input operands are either scalars
894-
// or tensors. The output operands are tensors and every tensor output is tied
895-
// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
896-
// tensor is tied to the i-th OpResult. The op may not have any additional
897-
// OpResults. Output operands and their tied OpResults have the same type.
898-
//
899-
// If the op has "buffer semantics", then the input operands are either memrefs
900-
// or other non-tensor types, e.g. scalar types. Furthermore, the output
901-
// operands are memrefs and the op has no results.
902-
//
903-
// Destination-passing style abstraction makes certain transformations easier.
904-
// For example, tiling implementation can extract/insert slices from/into the
905-
// destination of an op and use the resulting shaped value as an iter_arg in
906-
// the surrounding loop structure. As another example, bufferization does not
907-
// have to allocate new buffers for destinations (in case of in-place
908-
// bufferization) and can directly reuse the existing destination buffer.
909-
//
910-
// Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
911-
// where `%t` is the single input and `%d` is the single output. `%d` is tied
912-
// to `%r`.
913-
//
914-
// Example of an op that is not in destination style: `%r = tensor.pad %t`.
915-
// This op is not in destination style because `%r` and `%t` have different
916-
// shape.
917-
//
918-
// Each op that wants to implement DestinationStyleOpInterface needs to define
919-
// the getOutputsPositionRange() method.
920-
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
921-
let cppNamespace = "::mlir::linalg";
922-
let methods = [
923-
// This method has to be defined for every DPS op.
924-
InterfaceMethod<
925-
/*desc=*/"Return start and end indices of the output operands range.",
926-
/*retTy=*/"std::pair<int64_t, int64_t>",
927-
/*methodName=*/"getOutputsPositionRange",
928-
/*args=*/(ins),
929-
/*methodBody=*/"",
930-
/*defaultImplementation=*/""
931-
>,
932-
//===------------------------------------------------------------------===//
933-
// Operands handling.
934-
//===------------------------------------------------------------------===//
935-
// The operand list is assumed to start with the input operands and end
936-
// with the output operands. Therefore, all methods to access the inputs
937-
// and outputs can be expressed if the number of output operands is know.
938-
InterfaceMethod<
939-
/*desc=*/"Return the number of outputs.",
940-
/*retTy=*/"int64_t",
941-
/*methodName=*/"getNumOutputs",
942-
/*args=*/(ins),
943-
/*methodBody=*/"",
944-
/*defaultImplementation=*/[{
945-
auto [start, end] = $_op.getOutputsPositionRange();
946-
return end - start;
947-
}]
948-
>,
949-
InterfaceMethod<
950-
/*desc=*/"Return the output operands.",
951-
/*retTy=*/"OpOperandVector",
952-
/*methodName=*/"getOutputOperands",
953-
/*args=*/(ins),
954-
/*methodBody=*/"",
955-
/*defaultImplementation=*/[{
956-
auto [start, end] = $_op.getOutputsPositionRange();
957-
958-
OpOperandVector result;
959-
result.reserve(end - start);
960-
for (int i = start; i < end; ++i)
961-
result.push_back(&$_op->getOpOperand(i));
962-
return result;
963-
}]
964-
>,
965-
InterfaceMethod<
966-
/*desc=*/"Return the `i`-th output operand.",
967-
/*retTy=*/"OpOperand*",
968-
/*methodName=*/"getOutputOperand",
969-
/*args=*/(ins "int64_t":$i),
970-
/*methodBody=*/"",
971-
/*defaultImplementation=*/[{
972-
assert(i >= 0 && i < $_op.getNumOutputs());
973-
auto [start, end] = $_op.getOutputsPositionRange();
974-
return &$_op->getOpOperand(start + i);
975-
}]
976-
>,
977-
InterfaceMethod<
978-
/*desc=*/"Set the `i`-th output operand.",
979-
/*retTy=*/"void",
980-
/*methodName=*/"setOutputOperand",
981-
/*args=*/(ins "int64_t":$i, "Value":$value),
982-
/*methodBody=*/"",
983-
/*defaultImplementation=*/[{
984-
assert(i >= 0 && i < $_op.getNumOutputs());
985-
auto [start, end] = $_op.getOutputsPositionRange();
986-
$_op->setOperand(start + i, value);
987-
}]
988-
>,
989-
InterfaceMethod<
990-
/*desc=*/"Return the number of inputs.",
991-
/*retTy=*/"int64_t",
992-
/*methodName=*/"getNumInputs",
993-
/*args=*/(ins),
994-
/*methodBody=*/"",
995-
/*defaultImplementation=*/[{
996-
return $_op.getNumOperands() - $_op.getNumOutputs();
997-
}]
998-
>,
999-
InterfaceMethod<
1000-
/*desc=*/"Return the input operands.",
1001-
/*retTy=*/"OpOperandVector",
1002-
/*methodName=*/"getInputOperands",
1003-
/*args=*/(ins),
1004-
/*methodBody=*/"",
1005-
/*defaultImplementation=*/[{
1006-
auto [start, end] = $_op.getOutputsPositionRange();
1007-
int64_t numOutputs = end - start;
1008-
int64_t numOperands = $_op.getNumOperands();
1009-
1010-
OpOperandVector result;
1011-
result.reserve(numOperands - numOutputs);
1012-
for (int i = 0; i < start; ++i)
1013-
result.push_back(&$_op->getOpOperand(i));
1014-
for (int i = end; i < numOperands; ++i)
1015-
result.push_back(&$_op->getOpOperand(end + i));
1016-
1017-
return result;
1018-
}]
1019-
>,
1020-
InterfaceMethod<
1021-
/*desc=*/[{ Return the `i`-th input operand. }],
1022-
/*retTy=*/"OpOperand*",
1023-
/*methodName=*/"getInputOperand",
1024-
/*args=*/(ins "int64_t":$i),
1025-
/*methodBody=*/"",
1026-
/*defaultImplementation=*/[{
1027-
assert(i >= 0 && i < getNumInputs());
1028-
auto [start, end] = $_op.getOutputsPositionRange();
1029-
return &$_op->getOpOperand(i < start ? i : i + end - start) ;
1030-
}]
1031-
>,
1032-
//===------------------------------------------------------------------===//
1033-
// Input and Output arguments handling.
1034-
//===------------------------------------------------------------------===//
1035-
InterfaceMethod<
1036-
/*desc=*/"Return true if `opOperand` is an input.",
1037-
/*retTy=*/"bool",
1038-
/*methodName=*/"isInput",
1039-
/*args=*/(ins "OpOperand *":$opOperand),
1040-
/*methodBody=*/"",
1041-
/*defaultImplementation=*/[{
1042-
auto [start, end] = $_op.getOutputsPositionRange();
1043-
auto operandNumber = opOperand->getOperandNumber();
1044-
return operandNumber < start || operandNumber >= end;
1045-
}]
1046-
>,
1047-
InterfaceMethod<
1048-
/*desc=*/"Return true if `opOperand` is an output.",
1049-
/*retTy=*/"bool",
1050-
/*methodName=*/"isOutput",
1051-
/*args=*/(ins "OpOperand *":$opOperand),
1052-
/*methodBody=*/"",
1053-
/*defaultImplementation=*/[{
1054-
auto [start, end] = $_op.getOutputsPositionRange();
1055-
auto operandNumber = opOperand->getOperandNumber();
1056-
return operandNumber >= start && operandNumber < end;
1057-
}]
1058-
>,
1059-
InterfaceMethod<
1060-
/*desc=*/"Return true if the `opOperand` is a scalar value.",
1061-
/*retTy=*/"bool",
1062-
/*methodName=*/"isScalar",
1063-
/*args=*/(ins "OpOperand*":$opOperand),
1064-
/*methodBody=*/"",
1065-
/*defaultImplementation=*/[{
1066-
assert(opOperand->getOwner() == this->getOperation());
1067-
return !opOperand->get().getType().template isa<ShapedType>();
1068-
}]
1069-
>,
1070-
InterfaceMethod<
1071-
/*desc=*/"Return the result tied to `opOperand`.",
1072-
/*retTy=*/"OpResult",
1073-
/*methodName=*/"getTiedOpResult",
1074-
/*args=*/(ins "OpOperand*":$opOperand),
1075-
/*methodBody=*/"",
1076-
/*defaultImplementation=*/[{
1077-
assert(opOperand->getOwner() == this->getOperation());
1078-
1079-
auto [start, end] = $_op.getOutputsPositionRange();
1080-
int64_t resultIndex = opOperand->getOperandNumber() - start;
1081-
assert(resultIndex >= 0 &&
1082-
resultIndex < $_op->getNumResults() );
1083-
return $_op->getResult(resultIndex);
1084-
}]
1085-
>,
1086-
//===------------------------------------------------------------------===//
1087-
// Other interface methods.
1088-
//===------------------------------------------------------------------===//
1089-
InterfaceMethod<
1090-
/*desc=*/"Return whether the op has only MemRef input and outputs.",
1091-
/*retTy=*/"bool",
1092-
/*methodName=*/"hasBufferSemantics",
1093-
/*args=*/(ins),
1094-
/*methodBody=*/"",
1095-
/*defaultImplementation=*/[{
1096-
return $_op->getNumResults() == 0 &&
1097-
llvm::all_of($_op->getOpOperands(),
1098-
[&](OpOperand &opOperand) {
1099-
return isScalar(&opOperand) ||
1100-
opOperand.get().getType().template isa<MemRefType>();
1101-
});
1102-
}]
1103-
>,
1104-
InterfaceMethod<
1105-
/*desc=*/"Return whether the op has only RankedTensor input and outputs.",
1106-
/*retTy=*/"bool",
1107-
/*methodName=*/"hasTensorSemantics",
1108-
/*args=*/(ins),
1109-
/*methodBody=*/"",
1110-
/*defaultImplementation=*/[{
1111-
return llvm::all_of($_op->getOpOperands(),
1112-
[&](OpOperand &opOperand) {
1113-
return isScalar(&opOperand) ||
1114-
opOperand.get().getType().template isa<RankedTensorType>();
1115-
});
1116-
}]
1117-
>,
1118-
//===------------------------------------------------------------------===//
1119-
// Other static interface methods.
1120-
//===------------------------------------------------------------------===//
1121-
InterfaceMethod<
1122-
/*desc=*/[{
1123-
Clone the current operation with the given location and operands. This
1124-
is used to abstract away the optional underlying region creation. This
1125-
does not change the balance between input, output_buffer and
1126-
init_tensors operands.
1127-
}],
1128-
/*retTy=*/"Operation *",
1129-
/*methodName=*/"clone",
1130-
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
1131-
"ValueRange":$operands),
1132-
[{
1133-
BlockAndValueMapping bvm;
1134-
OperationState state(
1135-
loc, ConcreteOp::getOperationName(), operands, resultTypes,
1136-
$_op->getAttrs());
1137-
for (Region &r : $_op->getRegions())
1138-
r.cloneInto(state.addRegion(), bvm);
1139-
return b.create(state);
1140-
}]
1141-
>,
1142-
InterfaceMethod<
1143-
/*desc=*/[{
1144-
Clone the current operation with the given location, operands
1145-
and BlockAndValueMapping but leave the regions empty. This is
1146-
used to abstract away the optional underlying region creation.
1147-
This does not change the balance between input, output_buffer
1148-
and init_tensors operands.
1149-
}],
1150-
/*retTy=*/"Operation *",
1151-
/*methodName=*/"cloneWithoutRegions",
1152-
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
1153-
"ValueRange":$operands),
1154-
[{
1155-
OperationState state(
1156-
loc, ConcreteOp::getOperationName(), operands, resultTypes,
1157-
$_op->getAttrs());
1158-
for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
1159-
state.addRegion();
1160-
return b.create(state);
1161-
}]
1162-
>
1163-
];
1164-
1165-
let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
1166-
let verifyWithRegions = 1;
1167-
}
1168-
1169882
#endif // LINALG_IR_LINALGINTERFACES

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1818
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
1919
include "mlir/Interfaces/ControlFlowInterfaces.td"
20+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
2021
include "mlir/Interfaces/InferTypeOpInterface.td"
2122
include "mlir/Interfaces/SideEffectInterfaces.td"
2223
include "mlir/IR/OpAsmInterface.td"
@@ -279,7 +280,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
279280
int64_t getNumOperands = this->getNumOperands();
280281
return {getNumOperands - 1, getNumOperands};
281282
}
282-
linalg::OpOperandVector getOpOperandsMatchingBBargs() {
283+
OpOperandVector getOpOperandsMatchingBBargs() {
283284
return getInputOperands();
284285
}
285286

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
33
add_mlir_interface(ControlFlowInterfaces)
44
add_mlir_interface(CopyOpInterface)
55
add_mlir_interface(DerivedAttributeOpInterface)
6+
add_mlir_interface(DestinationStyleOpInterface)
67
add_mlir_interface(InferIntRangeInterface)
78
add_mlir_interface(InferTypeOpInterface)
89
add_mlir_interface(LoopLikeInterface)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===- DestinationStyleOpInterface.h ----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
10+
#define MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
11+
12+
#include "mlir/IR/BlockAndValueMapping.h"
13+
#include "mlir/IR/Builders.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
#include "mlir/IR/Value.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
19+
namespace mlir {
20+
/// OpOperand vector that implicitly converts to a Value vector.
21+
struct OpOperandVector : public llvm::SmallVector<OpOperand *> {
22+
operator SmallVector<Value>();
23+
};
24+
25+
namespace detail {
26+
/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
27+
LogicalResult verifyDestinationStyleOpInterface(Operation *op);
28+
} // namespace detail
29+
} // namespace mlir
30+
31+
/// Include the generated interface declarations.
32+
#include "mlir/Interfaces/DestinationStyleOpInterface.h.inc"
33+
34+
#endif // MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_

0 commit comments

Comments
 (0)