Skip to content

Commit 28fef90

Browse files
committed
[mlir][VectorOps] Fix folding of vector.extract from stretch vector.broadcast
Previously, foldExtractFromBroadcast() would incorrectly fold: func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 { %0 = vector.broadcast %src : vector<3x1x2xf32> to vector<3x4x2xf32> %1 = vector.extract %0[0, 2, 0] : vector<3x4x2xf32> return %1: f32 } to: func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 { %0 = vector.extract %src[0, 2, 0] : vector<3x1x2xf32> return %0: f32 } This was due to the wrong offset being used when zeroing the "dim-1" broadcasted dims. It should use the difference in rank across the broadcast as the starting offset, as the ranks after that are the ones that could have been stretched. Reviewed By: awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D157003
1 parent fcb0294 commit 28fef90

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,18 +1467,21 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
14671467
return Value();
14681468

14691469
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1470-
int64_t rankDiff = broadcastSrcRank - extractResultRank;
1470+
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1471+
14711472
// Detect all the positions that come from "dim-1" broadcasting.
14721473
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching
14731474
// extract position to `0` when extracting from the source operand.
14741475
llvm::SetVector<int64_t> broadcastedUnitDims =
14751476
broadcastOp.computeBroadcastedUnitDims();
14761477
SmallVector<int64_t> extractPos(extractOp.getPosition());
1477-
for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
1478+
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1479+
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
14781480
if (broadcastedUnitDims.contains(i))
14791481
extractPos[i] = 0;
14801482
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
14811483
// matching extract position when extracting from the source operand.
1484+
int64_t rankDiff = broadcastSrcRank - extractResultRank;
14821485
extractPos.erase(extractPos.begin(),
14831486
std::next(extractPos.begin(), extractPos.size() - rankDiff));
14841487
// OpBuilder is only used as a helper to build an I64ArrayAttr.
@@ -4953,7 +4956,8 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
49534956

49544957
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
49554958
// Eliminate splat constant transpose ops.
4956-
if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
4959+
if (auto attr =
4960+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
49574961
if (attr.isSplat())
49584962
return attr.reshape(getResultVectorType());
49594963

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,15 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
21042104
return %1: vector<1xf32>
21052105
}
21062106

2107+
// CHECK-LABEL: func.func @extract_from_stretch_broadcast
2108+
func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
2109+
// CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0, 0] : vector<3x1x2xf32>
2110+
// CHECK-NEXT: return %0 : f32
2111+
%0 = vector.broadcast %src : vector<3x1x2xf32> to vector<3x4x2xf32>
2112+
%1 = vector.extract %0[0, 2, 0] : vector<3x4x2xf32>
2113+
return %1: f32
2114+
}
2115+
21072116
// -----
21082117
// CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask
21092118
func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{

0 commit comments

Comments
 (0)