Skip to content

Commit eb8ba22

Browse files
committed
[mlir][Vector] Improve dynamic support for vector.extract(broadcast) folders
1 parent 8e61aae commit eb8ba22

File tree

2 files changed

+67
-27
lines changed

2 files changed

+67
-27
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,10 +1660,6 @@ static bool hasZeroDimVectors(Operation *op) {
16601660

16611661
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
16621662
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1663-
// TODO: Canonicalization for dynamic position not implemented yet.
1664-
if (extractOp.hasDynamicPosition())
1665-
return Value();
1666-
16671663
Operation *defOp = extractOp.getVector().getDefiningOp();
16681664
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
16691665
return Value();
@@ -1692,6 +1688,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16921688
broadcastVecType.getShape().take_back(extractResultRank))
16931689
return Value();
16941690

1691+
// The dim-1 broadcast -> ExtractOp folder requires in place operation
1692+
// modifications. For dynamic position, this means we have to change the
1693+
// number of operands. This cannot be done in place since it changes the
1694+
// operation storage. For dynamic positions, the dim-1 broadcasting should
1695+
// be implemented as a canonicalization pattern.
1696+
// TODO: Implement canonicalization pattern for dim-1 broadcasting +
1697+
// ExtractOp.
1698+
if (extractOp.hasDynamicPosition())
1699+
return Value();
1700+
16951701
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
16961702
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
16971703

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -710,24 +710,44 @@ func.func @fold_extract_transpose(
710710

711711
// -----
712712

713-
// CHECK-LABEL: fold_extract_broadcast
713+
// CHECK-LABEL: fold_extract_broadcast_same_type
714714
// CHECK-SAME: %[[A:.*]]: f32
715715
// CHECK: return %[[A]] : f32
716-
func.func @fold_extract_broadcast(%a : f32) -> f32 {
716+
func.func @fold_extract_broadcast_same_type(%a : f32,
717+
%idx0 : index,
718+
%idx1 : index) -> f32 {
717719
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
718-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
720+
// The indices don't matter for this folder, so we use mixed indices.
721+
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
719722
return %r : f32
720723
}
721724

722725
// -----
723726

724-
// CHECK-LABEL: fold_extract_broadcast_0dvec
727+
// CHECK-LABEL: fold_extract_broadcast_same_type_vec
728+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
729+
// CHECK: return %[[A]] : vector<4xf32>
730+
func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
731+
%idx0 : index)
732+
-> vector<4xf32> {
733+
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
734+
// The indices don't matter for this folder, so we use mixed indices.
735+
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
736+
return %r : vector<4xf32>
737+
}
738+
739+
// -----
740+
741+
// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
725742
// CHECK-SAME: %[[A:.*]]: vector<f32>
726743
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
727744
// CHECK: return %[[B]] : f32
728-
func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
745+
func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>,
746+
%idx0 : index,
747+
%idx1 : index) -> f32 {
729748
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
730-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
749+
// The indices don't matter for this folder, so we use mixed indices.
750+
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
731751
return %r : f32
732752
}
733753

@@ -747,57 +767,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
747767
// CHECK-LABEL: fold_extract_splat
748768
// CHECK-SAME: %[[A:.*]]: f32
749769
// CHECK: return %[[A]] : f32
750-
func.func @fold_extract_splat(%a : f32) -> f32 {
770+
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
751771
%b = vector.splat %a : vector<1x2x4xf32>
752-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
772+
// The indices don't matter for this folder, so we use mixed indices.
773+
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
753774
return %r : f32
754775
}
755776

756777
// -----
757778

758-
// CHECK-LABEL: fold_extract_broadcast_vector
779+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
759780
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
760-
// CHECK: return %[[A]] : vector<4xf32>
761-
func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
781+
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
782+
// CHECK: return %[[R]] : f32
783+
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
762784
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
763-
%r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
764-
return %r : vector<4xf32>
785+
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
786+
return %r : f32
765787
}
766788

767789
// -----
768790

769-
// CHECK-LABEL: fold_extract_broadcast
791+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
770792
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
771-
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
793+
// CHECK-SAME: %[[IDX:.*]]: index
794+
// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
795+
// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
772796
// CHECK: return %[[R]] : f32
773-
func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
797+
// This folder is not yet implemented. Check that this does not fold.
798+
func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
799+
%a : vector<4xf32>,
800+
%idx : index) -> f32 {
774801
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
775-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
802+
%r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
776803
return %r : f32
777804
}
778805

779806
// -----
780807

781-
// CHECK-LABEL: fold_extract_broadcast
808+
// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
782809
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
783810
// CHECK: return %[[B]] : vector<4xf32>
784-
func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
811+
func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32,
812+
%idx0 : index)
813+
-> vector<4xf32> {
785814
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
786-
%r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
815+
// The indices don't matter for this canonicalizer, so we use mixed indices.
816+
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
787817
return %r : vector<4xf32>
788818
}
789819

790820
// -----
791821

792-
// CHECK-LABEL: fold_extract_broadcast
822+
// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
793823
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
794824
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
795825
// CHECK: return %[[R]] : vector<8xf32>
796-
func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
826+
func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
827+
%idx0 : index)
828+
-> vector<8xf32> {
797829
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
798-
%r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
830+
// The indices don't matter for this canonicalizer, so we use mixed indices.
831+
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
799832
return %r : vector<8xf32>
800833
}
834+
801835
// -----
802836

803837
// CHECK-LABEL: @fold_extract_shuffle

0 commit comments

Comments
 (0)