@@ -1048,134 +1048,6 @@ struct SplatOpInterface
 }
 };
 
-/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
-/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
-/// on subviews instead of memref.store.
-struct ConcatOpInterface
-    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
-                                                    tensor::ConcatOp> {
-
-  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
-                                      const AnalysisState &state) const {
-    return {};
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto concatOp = cast<tensor::ConcatOp>(op);
-
-    // Allocate memory.
-    Location loc = op->getLoc();
-    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, concatOp.getResult(), options,
-        /*copy=*/false);
-    if (failed(tensorAlloc))
-      return failure();
-    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
-
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
-
-    MemRefLayoutAttrInterface layout;
-    MemRefType memrefType =
-        MemRefType::get(concatOp.getResultType().getShape(),
-                        concatOp.getResultType().getElementType(), layout);
-    Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
-
-    // Extract the dimension for the concat op
-    uint64_t concatDim = concatOp.getDim();
-    bool dynamicConcatDim = false;
-
-    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
-                                      rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(tensorType.getRank(),
-                                      rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-
-    for (const auto &[dimIdx, dimSize] :
-         llvm::enumerate(tensorType.getShape())) {
-      if (dimSize == ShapedType::kDynamic) {
-        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
-        sizes.push_back(dimOp.getResult());
-        if (dimIdx == concatDim)
-          dynamicConcatDim = true;
-      } else {
-        sizes.push_back(rewriter.getIndexAttr(dimSize));
-      }
-    }
-
-    int64_t concatDimOffset = 0;
-    std::optional<Value> dynamicOffset;
-    std::optional<Value> dynamicSize;
-    if (dynamicConcatDim) {
-      // One or more operands have dynamic size, so we must accumulate the
-      // offset with arith ops.
-      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    }
-
-    for (auto operand : concatOp.getInputs()) {
-      // Get the buffer for the operand.
-      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
-      if (failed(srcBuffer))
-        return failure();
-
-      // Each operand may have a different size along the concat dimension,
-      // so the offset on that axis must accumulate through the loop, and the
-      // size must change to the size of the current operand.
-      auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
-      if (dynamicConcatDim) {
-        offsets[concatDim] = dynamicOffset.value();
-        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
-                          .getResult();
-        sizes[concatDim] = dynamicSize.value();
-      } else {
-        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
-      }
-
-      // Create a subview of the destination buffer.
-      auto dstMemrefType = cast<MemRefType>(memrefType);
-      MemRefType subviewMemRefType =
-          memref::SubViewOp::inferRankReducedResultType(
-              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
-              strides);
-      Value subview = rewriter.create<memref::SubViewOp>(
-          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
-
-      // Copy the source buffer into the destination subview.
-      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
-        return failure();
-
-      if (dynamicConcatDim) {
-        dynamicOffset = rewriter.create<arith::AddIOp>(
-            loc, dynamicOffset.value(), dynamicSize.value());
-      } else {
-        concatDimOffset += operandConcatDimSize;
-      }
-    }
-
-    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
-    return success();
-  }
-};
-
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1185,7 +1057,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
-    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
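For reference, the removed interface bufferized tensor.concat into a fresh allocation plus one memref.copy per input: each input buffer was copied into a memref.subview of the destination, with the subview offset accumulating along the concat dimension. Below is a minimal sketch of the static-shape case; the function and value names are illustrative, and the strided result layouts are an assumption derived from the subview arithmetic rather than output captured from the pass.

    func.func @concat_static(%a: tensor<2x4xf32>, %b: tensor<3x4xf32>) -> tensor<5x4xf32> {
      %0 = tensor.concat dim(0) %a, %b
          : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
      return %0 : tensor<5x4xf32>
    }

    // Approximate bufferized form: allocate the 5x4 result, then copy the
    // 2-row input at row offset 0 and the 3-row input at row offset 2.
    // %a_buf and %b_buf stand for the operands' buffers.
    %alloc = memref.alloc() : memref<5x4xf32>
    %sv0 = memref.subview %alloc[0, 0] [2, 4] [1, 1]
        : memref<5x4xf32> to memref<2x4xf32, strided<[4, 1]>>
    memref.copy %a_buf, %sv0
        : memref<2x4xf32> to memref<2x4xf32, strided<[4, 1]>>
    %sv1 = memref.subview %alloc[2, 0] [3, 4] [1, 1]
        : memref<5x4xf32> to memref<3x4xf32, strided<[4, 1], offset: 8>>
    memref.copy %b_buf, %sv1
        : memref<3x4xf32> to memref<3x4xf32, strided<[4, 1], offset: 8>>

When the concat dimension is dynamic, the removed code kept the same structure but materialized each input's size with memref.dim and accumulated the running offset with arith.addi, as in the dynamicConcatDim branches of the deleted bufferize() above.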