
Commit 7b8bc1b

Revert "[mlir][bufferization] implement BufferizableOpInterface for concat op (#140171)"
This reverts commit 6d9ce67. Multiple buildbot failures have been reported: #140171
1 parent 6d9ce67 · commit 7b8bc1b

3 files changed: 2 additions & 222 deletions

mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
       >();
   addInterfaces<TensorInlinerInterface>();
   declarePromisedInterfaces<
-      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
-      DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
+      EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
       GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
       ReshapeOp, SplatOp>();
   declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
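
For context on why this file changes together with the external-model registration below: declarePromisedInterfaces records a promise that an interface implementation will be attached to the op later, and MLIR raises an error if the promise is never fulfilled. A minimal sketch of that pairing, using hypothetical MyDialect/MyOp/MyOpModel names (not from this patch); the real code does the same for TensorDialect and ConcatOp:

// Sketch only; hypothetical names. The dialect promises the interface...
void MyDialect::initialize() {
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, MyOp>();
}

// ...and a separately registered extension fulfills it. Dropping the
// attachInterface call (as this revert does for ConcatOp) requires
// dropping the promise above as well.
void registerMyOpBufferizationModel(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
    MyOp::attachInterface<MyOpModel>(*ctx);
  });
}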

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 0 additions & 129 deletions
@@ -1048,134 +1048,6 @@ struct SplatOpInterface
   }
 };
 
-/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
-/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
-/// on subviews instead of memref.store.
-struct ConcatOpInterface
-    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
-                                                    tensor::ConcatOp> {
-
-  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
-                                      const AnalysisState &state) const {
-    return {};
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto concatOp = cast<tensor::ConcatOp>(op);
-
-    // Allocate memory.
-    Location loc = op->getLoc();
-    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, concatOp.getResult(), options,
-        /*copy=*/false);
-    if (failed(tensorAlloc))
-      return failure();
-    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
-
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
-
-    MemRefLayoutAttrInterface layout;
-    MemRefType memrefType =
-        MemRefType::get(concatOp.getResultType().getShape(),
-                        concatOp.getResultType().getElementType(), layout);
-    Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
-
-    // Extract the dimension for the concat op
-    uint64_t concatDim = concatOp.getDim();
-    bool dynamicConcatDim = false;
-
-    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
-                                      rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(tensorType.getRank(),
-                                      rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-
-    for (const auto &[dimIdx, dimSize] :
-         llvm::enumerate(tensorType.getShape())) {
-      if (dimSize == ShapedType::kDynamic) {
-        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
-        sizes.push_back(dimOp.getResult());
-        if (dimIdx == concatDim)
-          dynamicConcatDim = true;
-      } else {
-        sizes.push_back(rewriter.getIndexAttr(dimSize));
-      }
-    }
-
-    int64_t concatDimOffset = 0;
-    std::optional<Value> dynamicOffset;
-    std::optional<Value> dynamicSize;
-    if (dynamicConcatDim) {
-      // One or more operands have dynamic size, so we must accumulate the
-      // offset with arith ops.
-      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    }
-
-    for (auto operand : concatOp.getInputs()) {
-      // Get the buffer for the operand.
-      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
-      if (failed(srcBuffer))
-        return failure();
-
-      // Each operand may have a different size along the concat dimension,
-      // so the offset on that axis must accumulate through the loop, and the
-      // size must change to the size of the current operand.
-      auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
-      if (dynamicConcatDim) {
-        offsets[concatDim] = dynamicOffset.value();
-        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
-                          .getResult();
-        sizes[concatDim] = dynamicSize.value();
-      } else {
-        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
-      }
-
-      // Create a subview of the destination buffer.
-      auto dstMemrefType = cast<MemRefType>(memrefType);
-      MemRefType subviewMemRefType =
-          memref::SubViewOp::inferRankReducedResultType(
-              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
-              strides);
-      Value subview = rewriter.create<memref::SubViewOp>(
-          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
-
-      // Copy the source buffer into the destination subview.
-      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
-        return failure();
-
-      if (dynamicConcatDim) {
-        dynamicOffset = rewriter.create<arith::AddIOp>(
-            loc, dynamicOffset.value(), dynamicSize.value());
-      } else {
-        concatDimOffset += operandConcatDimSize;
-      }
-    }
-
-    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
-    return success();
-  }
-};
-
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1185,7 +1057,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
-    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
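
The removed bufferize() implements concatenation as: allocate a destination buffer, then for each input create a subview of the destination and memref.copy into it, advancing the offset along the concat dimension by that input's extent. A standalone plain-C++ sketch of the static-shape bookkeeping (illustrative shapes, mirroring the deleted 8x4 ++ 8x5 test below):

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only (plain C++, not MLIR): offsets are zero except along the
// concat dimension, where they accumulate each operand's extent; strides
// are all 1, as in the removed loop.
int main() {
  const size_t concatDim = 1;
  // tensor<8x4xf32> ++ tensor<8x5xf32> -> tensor<8x9xf32> along dim 1.
  std::vector<std::array<int64_t, 2>> operandShapes = {{8, 4}, {8, 5}};
  int64_t concatDimOffset = 0;
  for (const auto &shape : operandShapes) {
    std::array<int64_t, 2> offsets = {0, 0};
    offsets[concatDim] = concatDimOffset;
    // Corresponds to memref.subview %alloc[offsets] [sizes] [1, 1],
    // then memref.copy of the operand buffer into that subview.
    std::cout << "copy into subview offsets=[" << offsets[0] << ", "
              << offsets[1] << "], sizes=[" << shape[0] << ", " << shape[1]
              << "]\n";
    concatDimOffset += shape[concatDim]; // advance along the concat dim
  }
  std::cout << "destination extent along dim 1 = " << concatDimOffset << "\n";
}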

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 0 additions & 91 deletions
@@ -615,97 +615,6 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @tensor.concat(
-// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
-// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
-// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
-// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
-// CHECK: return %[[RET]]
-// CHECK: }
-func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
-  %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
-  return %t : tensor<16xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor.concat_different_shapes(
-// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
-// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
-// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
-// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
-// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
-// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
-// CHECK: return %[[RET]]
-// CHECK: }
-func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
-  %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
-  return %t : tensor<8x9xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor.concat_dynamic(
-// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
-// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
-// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
-// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
-// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
-// CHECK: %[[ALLOC:.*]] = memref.alloc
-// CHECK-SAME: memref<8x?xf32>
-// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
-// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
-// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
-// CHECK: return %[[RET]]
-// CHECK: }
-func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
-  %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
-  return %t : tensor<8x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
-// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
-// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
-// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
-// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
-// CHECK: %[[ALLOC:.*]] = memref.alloc
-// CHECK-SAME: memref<?x?xf32>
-// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
-// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
-// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
-// CHECK: return %[[RET]]
-// CHECK: }
-func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %t : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
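
The two dynamic tests above exercise the other branch of the removed loop: when the concat dimension is dynamic, the running offset is a runtime index updated with arith.addi (checked as %[[OFFSET_2]] above) rather than a folded constant. A plain-C++ sketch of that runtime accumulation, with extents standing in for memref.dim results (values are illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

// Extents standing in for the memref.dim results (%[[F_DIM]], %[[G_DIM]])
// queried on each operand buffer at runtime; illustrative values only.
int main() {
  std::vector<int64_t> dimResults = {6, 10};
  int64_t offset = 0; // %[[OFFSET]] = arith.constant 0 : index
  for (int64_t size : dimResults) {
    // memref.subview at a runtime offset, then memref.copy into it.
    std::cout << "copy at offset " << offset << ", size " << size << "\n";
    offset += size; // %[[OFFSET_2]] = arith.addi %[[OFFSET]], %[[F_DIM]]
  }
}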
