Skip to content

Commit e80df20

Browse files
committed
fixup! [mlir][SVE] Add more e2e test for vector.contract
Further generalization
1 parent 2558a03 commit e80df20

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,20 @@ func.func @dot_product_i32() {
6464
%vector_b = arith.constant dense<314> : vector<[4]xi32>
6565
%vector_c = arith.constant dense<0> : vector<[4]xi32>
6666

67-
// The result of this dot-product will depend
68-
// on the vector length, so we are unable to verify it.
67+
// DOT PRODUCT 1
6968
%dp1 = vector.contract #dotp_trait %vector_a, %vector_b, %acc
7069
: vector<[4]xi32>, vector<[4]xi32> into i32
71-
// Dot product should be (123 * 314) * 4 * vscale, so ...
70+
// Dot product should be:
71+
// * val = (123 * 314) * 4 * vscale,
72+
// so ...
7273
%vscale = vector.vscale
7374
%vscale_i32 = arith.index_cast %vscale : index to i32
74-
%dp1_divvl = arith.divui %dp1, %vscale_i32 : i32
75-
// ... %dp/%vscale = 123 * 314 * 4 = 154488
75+
%dp1_div = arith.divui %dp1, %vscale_i32 : i32
76+
// ... val / vscale = 123 * 314 * 4 = 154488
7677
// DP: 154488
77-
vector.print %dp1_divvl : i32
78+
vector.print %dp1_div : i32
7879

80+
// DOT PRODUCT 2
7981
// The result of this dot-product should be 0.
8082
%dp2 = vector.contract #dotp_trait %vector_a, %vector_c, %acc
8183
: vector<[4]xi32>, vector<[4]xi32> into i32
@@ -96,18 +98,27 @@ func.func @matvec_i32() {
9698
%vector_b = arith.constant dense<314> : vector<[4]xi32>
9799
%vector_c = arith.constant dense<0> : vector<[4]xi32>
98100

99-
// The result of this matvec will depend on the vector length, so we are
100-
// unable to verify it.
101-
%dp1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
101+
// MATVEC 1
102+
%mv1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
102103
: vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
103-
// MV: {{[0-9]*}}, {{[0-9]*}}, {{[0-9]*}}
104-
vector.print %dp1 : vector<3xi32>
105-
106-
// The result of this matvc should be a vector of 0s.
107-
%dp2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
104+
// Every element in the output vector is a result of a dot product, for
105+
// which:
106+
// val = (123 * 314) * 4 * vscale
107+
// so ...
108+
%vscale = vector.vscale
109+
%vscale_v = vector.splat %vscale : vector<3xindex>
110+
%vscale_i32 = arith.index_cast %vscale_v : vector<3xindex> to vector<3xi32>
111+
%mv1_div = arith.divui %mv1, %vscale_i32 : vector<3xi32>
112+
// ... val / vscale = 123 * 314 * 4 = 154488
113+
// MV: 154488, 154488, 154488
114+
vector.print %mv1_div : vector<3xi32>
115+
116+
// MATVEC 2
117+
// The result of this matvec should be a vector of 0s.
118+
%mv2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
108119
: vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
109120
// MV: 0, 0, 0
110-
vector.print %dp2 : vector<3xi32>
121+
vector.print %mv2 : vector<3xi32>
111122

112123
// MV: SVE: END OF TEST OUTPUT
113124
vector.print str "SVE: END OF TEST OUTPUT"

0 commit comments

Comments
 (0)