@@ -879,291 +879,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
879
879
let verifyWithRegions = 1;
880
880
}
881
881
882
- // Ops that are in destination style have designated output operands, which act
883
- // as initial tensor values for the results of the operation or the output
884
- // buffers to which the results of the op will be written.
885
- //
886
- // Output operands must be tensors or memrefs. Input operands can have any
887
- // type. All non-output operands are inputs.
888
-
889
- // It is assumed that the output operands of the op are the operands at
890
- // position [start, end). The positions are defined by getOutputsPositionRange
891
- // method. All non-output operands are "inputs" of the DPS op.
892
-
893
- // If the op has "tensor semantics", then the input operands are either scalars
894
- // or tensors. The output operands are tensors and every tensor output is tied
895
- // to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
896
- // tensor is tied to the i-th OpResult. The op may not have any additional
897
- // OpResults. Output operands and their tied OpResults have the same type.
898
- //
899
- // If the op has "buffer semantics", then the input operands are either memrefs
900
- // or other non-tensor types, e.g. scalar types. Furthermore, the output
901
- // operands are memrefs and the op has no results.
902
- //
903
- // Destination-passing style abstraction makes certain transformations easier.
904
- // For example, tiling implementation can extract/insert slices from/into the
905
- // destination of an op and use the resulting shaped value as an iter_arg in
906
- // the surrounding loop structure. As another example, bufferization does not
907
- // have to allocate new buffers for destinations (in case of in-place
908
- // bufferization) and can directly reuse the existing destination buffer.
909
- //
910
- // Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
911
- // where `%t` is the single input and `%d` is the single output. `%d` is tied
912
- // to `%r`.
913
- //
914
- // Example of an op that is not in destination style: `%r = tensor.pad %t`.
915
- // This op is not in destination style because `%r` and `%t` have different
916
- // shape.
917
- //
918
- // Each op that wants to implement DestinationStyleOpInterface needs to define
919
- // the getOutputsPositionRange() method.
920
- def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
921
- let cppNamespace = "::mlir::linalg";
922
- let methods = [
923
- // This method has to be defined for every DPS op.
924
- InterfaceMethod<
925
- /*desc=*/"Return start and end indices of the output operands range.",
926
- /*retTy=*/"std::pair<int64_t, int64_t>",
927
- /*methodName=*/"getOutputsPositionRange",
928
- /*args=*/(ins),
929
- /*methodBody=*/"",
930
- /*defaultImplementation=*/""
931
- >,
932
- //===------------------------------------------------------------------===//
933
- // Operands handling.
934
- //===------------------------------------------------------------------===//
935
- // The operand list is assumed to start with the input operands and end
936
- // with the output operands. Therefore, all methods to access the inputs
937
- // and outputs can be expressed if the number of output operands is know.
938
- InterfaceMethod<
939
- /*desc=*/"Return the number of outputs.",
940
- /*retTy=*/"int64_t",
941
- /*methodName=*/"getNumOutputs",
942
- /*args=*/(ins),
943
- /*methodBody=*/"",
944
- /*defaultImplementation=*/[{
945
- auto [start, end] = $_op.getOutputsPositionRange();
946
- return end - start;
947
- }]
948
- >,
949
- InterfaceMethod<
950
- /*desc=*/"Return the output operands.",
951
- /*retTy=*/"OpOperandVector",
952
- /*methodName=*/"getOutputOperands",
953
- /*args=*/(ins),
954
- /*methodBody=*/"",
955
- /*defaultImplementation=*/[{
956
- auto [start, end] = $_op.getOutputsPositionRange();
957
-
958
- OpOperandVector result;
959
- result.reserve(end - start);
960
- for (int i = start; i < end; ++i)
961
- result.push_back(&$_op->getOpOperand(i));
962
- return result;
963
- }]
964
- >,
965
- InterfaceMethod<
966
- /*desc=*/"Return the `i`-th output operand.",
967
- /*retTy=*/"OpOperand*",
968
- /*methodName=*/"getOutputOperand",
969
- /*args=*/(ins "int64_t":$i),
970
- /*methodBody=*/"",
971
- /*defaultImplementation=*/[{
972
- assert(i >= 0 && i < $_op.getNumOutputs());
973
- auto [start, end] = $_op.getOutputsPositionRange();
974
- return &$_op->getOpOperand(start + i);
975
- }]
976
- >,
977
- InterfaceMethod<
978
- /*desc=*/"Set the `i`-th output operand.",
979
- /*retTy=*/"void",
980
- /*methodName=*/"setOutputOperand",
981
- /*args=*/(ins "int64_t":$i, "Value":$value),
982
- /*methodBody=*/"",
983
- /*defaultImplementation=*/[{
984
- assert(i >= 0 && i < $_op.getNumOutputs());
985
- auto [start, end] = $_op.getOutputsPositionRange();
986
- $_op->setOperand(start + i, value);
987
- }]
988
- >,
989
- InterfaceMethod<
990
- /*desc=*/"Return the number of inputs.",
991
- /*retTy=*/"int64_t",
992
- /*methodName=*/"getNumInputs",
993
- /*args=*/(ins),
994
- /*methodBody=*/"",
995
- /*defaultImplementation=*/[{
996
- return $_op.getNumOperands() - $_op.getNumOutputs();
997
- }]
998
- >,
999
- InterfaceMethod<
1000
- /*desc=*/"Return the input operands.",
1001
- /*retTy=*/"OpOperandVector",
1002
- /*methodName=*/"getInputOperands",
1003
- /*args=*/(ins),
1004
- /*methodBody=*/"",
1005
- /*defaultImplementation=*/[{
1006
- auto [start, end] = $_op.getOutputsPositionRange();
1007
- int64_t numOutputs = end - start;
1008
- int64_t numOperands = $_op.getNumOperands();
1009
-
1010
- OpOperandVector result;
1011
- result.reserve(numOperands - numOutputs);
1012
- for (int i = 0; i < start; ++i)
1013
- result.push_back(&$_op->getOpOperand(i));
1014
- for (int i = end; i < numOperands; ++i)
1015
- result.push_back(&$_op->getOpOperand(end + i));
1016
-
1017
- return result;
1018
- }]
1019
- >,
1020
- InterfaceMethod<
1021
- /*desc=*/[{ Return the `i`-th input operand. }],
1022
- /*retTy=*/"OpOperand*",
1023
- /*methodName=*/"getInputOperand",
1024
- /*args=*/(ins "int64_t":$i),
1025
- /*methodBody=*/"",
1026
- /*defaultImplementation=*/[{
1027
- assert(i >= 0 && i < getNumInputs());
1028
- auto [start, end] = $_op.getOutputsPositionRange();
1029
- return &$_op->getOpOperand(i < start ? i : i + end - start) ;
1030
- }]
1031
- >,
1032
- //===------------------------------------------------------------------===//
1033
- // Input and Output arguments handling.
1034
- //===------------------------------------------------------------------===//
1035
- InterfaceMethod<
1036
- /*desc=*/"Return true if `opOperand` is an input.",
1037
- /*retTy=*/"bool",
1038
- /*methodName=*/"isInput",
1039
- /*args=*/(ins "OpOperand *":$opOperand),
1040
- /*methodBody=*/"",
1041
- /*defaultImplementation=*/[{
1042
- auto [start, end] = $_op.getOutputsPositionRange();
1043
- auto operandNumber = opOperand->getOperandNumber();
1044
- return operandNumber < start || operandNumber >= end;
1045
- }]
1046
- >,
1047
- InterfaceMethod<
1048
- /*desc=*/"Return true if `opOperand` is an output.",
1049
- /*retTy=*/"bool",
1050
- /*methodName=*/"isOutput",
1051
- /*args=*/(ins "OpOperand *":$opOperand),
1052
- /*methodBody=*/"",
1053
- /*defaultImplementation=*/[{
1054
- auto [start, end] = $_op.getOutputsPositionRange();
1055
- auto operandNumber = opOperand->getOperandNumber();
1056
- return operandNumber >= start && operandNumber < end;
1057
- }]
1058
- >,
1059
- InterfaceMethod<
1060
- /*desc=*/"Return true if the `opOperand` is a scalar value.",
1061
- /*retTy=*/"bool",
1062
- /*methodName=*/"isScalar",
1063
- /*args=*/(ins "OpOperand*":$opOperand),
1064
- /*methodBody=*/"",
1065
- /*defaultImplementation=*/[{
1066
- assert(opOperand->getOwner() == this->getOperation());
1067
- return !opOperand->get().getType().template isa<ShapedType>();
1068
- }]
1069
- >,
1070
- InterfaceMethod<
1071
- /*desc=*/"Return the result tied to `opOperand`.",
1072
- /*retTy=*/"OpResult",
1073
- /*methodName=*/"getTiedOpResult",
1074
- /*args=*/(ins "OpOperand*":$opOperand),
1075
- /*methodBody=*/"",
1076
- /*defaultImplementation=*/[{
1077
- assert(opOperand->getOwner() == this->getOperation());
1078
-
1079
- auto [start, end] = $_op.getOutputsPositionRange();
1080
- int64_t resultIndex = opOperand->getOperandNumber() - start;
1081
- assert(resultIndex >= 0 &&
1082
- resultIndex < $_op->getNumResults() );
1083
- return $_op->getResult(resultIndex);
1084
- }]
1085
- >,
1086
- //===------------------------------------------------------------------===//
1087
- // Other interface methods.
1088
- //===------------------------------------------------------------------===//
1089
- InterfaceMethod<
1090
- /*desc=*/"Return whether the op has only MemRef input and outputs.",
1091
- /*retTy=*/"bool",
1092
- /*methodName=*/"hasBufferSemantics",
1093
- /*args=*/(ins),
1094
- /*methodBody=*/"",
1095
- /*defaultImplementation=*/[{
1096
- return $_op->getNumResults() == 0 &&
1097
- llvm::all_of($_op->getOpOperands(),
1098
- [&](OpOperand &opOperand) {
1099
- return isScalar(&opOperand) ||
1100
- opOperand.get().getType().template isa<MemRefType>();
1101
- });
1102
- }]
1103
- >,
1104
- InterfaceMethod<
1105
- /*desc=*/"Return whether the op has only RankedTensor input and outputs.",
1106
- /*retTy=*/"bool",
1107
- /*methodName=*/"hasTensorSemantics",
1108
- /*args=*/(ins),
1109
- /*methodBody=*/"",
1110
- /*defaultImplementation=*/[{
1111
- return llvm::all_of($_op->getOpOperands(),
1112
- [&](OpOperand &opOperand) {
1113
- return isScalar(&opOperand) ||
1114
- opOperand.get().getType().template isa<RankedTensorType>();
1115
- });
1116
- }]
1117
- >,
1118
- //===------------------------------------------------------------------===//
1119
- // Other static interface methods.
1120
- //===------------------------------------------------------------------===//
1121
- InterfaceMethod<
1122
- /*desc=*/[{
1123
- Clone the current operation with the given location and operands. This
1124
- is used to abstract away the optional underlying region creation. This
1125
- does not change the balance between input, output_buffer and
1126
- init_tensors operands.
1127
- }],
1128
- /*retTy=*/"Operation *",
1129
- /*methodName=*/"clone",
1130
- (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
1131
- "ValueRange":$operands),
1132
- [{
1133
- BlockAndValueMapping bvm;
1134
- OperationState state(
1135
- loc, ConcreteOp::getOperationName(), operands, resultTypes,
1136
- $_op->getAttrs());
1137
- for (Region &r : $_op->getRegions())
1138
- r.cloneInto(state.addRegion(), bvm);
1139
- return b.create(state);
1140
- }]
1141
- >,
1142
- InterfaceMethod<
1143
- /*desc=*/[{
1144
- Clone the current operation with the given location, operands
1145
- and BlockAndValueMapping but leave the regions empty. This is
1146
- used to abstract away the optional underlying region creation.
1147
- This does not change the balance between input, output_buffer
1148
- and init_tensors operands.
1149
- }],
1150
- /*retTy=*/"Operation *",
1151
- /*methodName=*/"cloneWithoutRegions",
1152
- (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
1153
- "ValueRange":$operands),
1154
- [{
1155
- OperationState state(
1156
- loc, ConcreteOp::getOperationName(), operands, resultTypes,
1157
- $_op->getAttrs());
1158
- for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
1159
- state.addRegion();
1160
- return b.create(state);
1161
- }]
1162
- >
1163
- ];
1164
-
1165
- let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
1166
- let verifyWithRegions = 1;
1167
- }
1168
-
1169
882
#endif // LINALG_IR_LINALGINTERFACES
0 commit comments