Skip to content

Commit 67dc952

Browse files
committed
Allow dim-1 broadcasting + dynamic indices
1 parent 5f172d8 commit 67dc952

File tree

2 files changed

+9
-33
lines changed

2 files changed

+9
-33
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,16 +1688,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16881688
broadcastVecType.getShape().take_back(extractResultRank))
16891689
return Value();
16901690

1691-
// The dim-1 broadcast -> ExtractOp folder requires in-place operation
1692-
// modifications. For dynamic position, this means we have to change the
1693-
// number of operands. This cannot be done in place since it changes the
1694-
// operation storage. For dynamic dimensions, the dim-1 broadcasting should
1695-
// be implemented as a canonicalization pattern.
1696-
// TODO: Implement canonicalization pattern for dim-1 broadcasting +
1697-
// extractop.
1698-
if (extractOp.hasDynamicPosition())
1699-
return Value();
1700-
17011691
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
17021692
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
17031693

@@ -1706,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17061696
// extract position to `0` when extracting from the source operand.
17071697
llvm::SetVector<int64_t> broadcastedUnitDims =
17081698
broadcastOp.computeBroadcastedUnitDims();
1709-
SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1699+
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1700+
OpBuilder b(extractOp.getContext());
17101701
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
17111702
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
17121703
if (broadcastedUnitDims.contains(i))
1713-
extractPos[i] = 0;
1704+
extractPos[i] = b.getIndexAttr(0);
17141705
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
17151706
// matching extract position when extracting from the source operand.
17161707
int64_t rankDiff = broadcastSrcRank - extractResultRank;
17171708
extractPos.erase(extractPos.begin(),
17181709
std::next(extractPos.begin(), extractPos.size() - rankDiff));
17191710
// Split the mixed position into static indices and dynamic index operands,
// then update the extract op in place to read from the broadcast source.
1720-
OpBuilder b(extractOp.getContext());
1721-
extractOp.setOperand(0, source);
1722-
extractOp.setStaticPosition(extractPos);
1711+
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1712+
extractOp->setOperands(
1713+
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1714+
extractOp.setStaticPosition(staticPos);
17231715
return extractOp.getResult();
17241716
}
17251717

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -779,25 +779,9 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
779779
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
780780
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
781781
// CHECK: return %[[R]] : f32
782-
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
783-
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
784-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
785-
return %r : f32
786-
}
787-
788-
// -----
789-
790-
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_negative
791-
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
792-
// CHECK-SAME: %[[IDX:.*]]: index
793-
// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
794-
// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
795-
// CHECK: return %[[R]] : f32
796-
// This folder is not yet implemented. Check that this does not fold.
797-
func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_negative(
798-
%a : vector<4xf32>,
799-
%idx : index) -> f32 {
782+
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>, %idx : index) -> f32 {
800783
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
784+
// The indices don't matter for this folder, so we use mixed indices.
801785
%r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
802786
return %r : f32
803787
}

0 commit comments

Comments
 (0)