Skip to content

Commit 443cdd0

Browse files
authored
[RISCV] Fix a bug in partial.reduce lowering for zvqdotq .vx forms (#142185)
I'd missed a bitcast in the lowering. Unfortunately, that bitcast happens to be semantically required here, as the partial_reduce_* source expects an i8 element type, but the pseudos and patterns expect an i32 element type. This appears to only influence the .vx matching from the cases I've found so far, and LV does not yet generate anything which will exercise this. The reduce path (instead of the partial.reduce one) used by SLP currently manually constructs the i32 value, and then goes directly to the pseudos with their i32 arguments, not the partial_reduce nodes. We're basically losing the .vx matching on this path until we teach splat matching to be able to manually splat the i8 value into an i32 via LUI/ADDI.
1 parent 6a6aec6 commit 443cdd0

File tree

3 files changed

+77
-15
lines changed

3 files changed

+77
-15
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8412,13 +8412,18 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
84128412
assert(ArgVT == B.getSimpleValueType() &&
84138413
ArgVT.getVectorElementType() == MVT::i8);
84148414

8415+
// The zvqdotq pseudos are defined with sources and destination both
8416+
// being i32. This cast is needed for correctness to avoid incorrect
8417+
// .vx matching of i8 splats.
8418+
A = DAG.getBitcast(VT, A);
8419+
B = DAG.getBitcast(VT, B);
8420+
84158421
MVT ContainerVT = VT;
84168422
if (VT.isFixedLengthVector()) {
84178423
ContainerVT = getContainerForFixedLengthVector(VT);
84188424
Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
8419-
MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
8420-
A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
8421-
B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
8425+
A = convertToScalableVector(ContainerVT, A, DAG, Subtarget);
8426+
B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
84228427
}
84238428

84248429
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,6 @@ entry:
598598
ret <1 x i32> %res
599599
}
600600

601-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
602601
define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
603602
; NODOT-LABEL: vqdotu_vx_partial_reduce:
604603
; NODOT: # %bb.0: # %entry
@@ -618,10 +617,13 @@ define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
618617
;
619618
; DOT-LABEL: vqdotu_vx_partial_reduce:
620619
; DOT: # %bb.0: # %entry
621-
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
620+
; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
622621
; DOT-NEXT: vmv.s.x v9, zero
623622
; DOT-NEXT: li a0, 128
624-
; DOT-NEXT: vqdotu.vx v9, v8, a0
623+
; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
624+
; DOT-NEXT: vmv.v.x v10, a0
625+
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
626+
; DOT-NEXT: vqdotu.vv v9, v8, v10
625627
; DOT-NEXT: vmv1r.v v8, v9
626628
; DOT-NEXT: ret
627629
entry:
@@ -631,7 +633,6 @@ entry:
631633
ret <1 x i32> %res
632634
}
633635

634-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
635636
define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
636637
; NODOT-LABEL: vqdot_vx_partial_reduce:
637638
; NODOT: # %bb.0: # %entry
@@ -652,10 +653,13 @@ define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
652653
;
653654
; DOT-LABEL: vqdot_vx_partial_reduce:
654655
; DOT: # %bb.0: # %entry
655-
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
656+
; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
656657
; DOT-NEXT: vmv.s.x v9, zero
657658
; DOT-NEXT: li a0, 128
658-
; DOT-NEXT: vqdot.vx v9, v8, a0
659+
; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
660+
; DOT-NEXT: vmv.v.x v10, a0
661+
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
662+
; DOT-NEXT: vqdot.vv v9, v8, v10
659663
; DOT-NEXT: vmv1r.v v8, v9
660664
; DOT-NEXT: ret
661665
entry:
@@ -1372,7 +1376,6 @@ entry:
13721376
}
13731377

13741378

1375-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
13761379
define <4 x i32> @partial_of_sext(<16 x i8> %a) {
13771380
; NODOT-LABEL: partial_of_sext:
13781381
; NODOT: # %bb.0: # %entry
@@ -1393,10 +1396,11 @@ define <4 x i32> @partial_of_sext(<16 x i8> %a) {
13931396
;
13941397
; DOT-LABEL: partial_of_sext:
13951398
; DOT: # %bb.0: # %entry
1399+
; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
1400+
; DOT-NEXT: vmv.v.i v10, 1
13961401
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
13971402
; DOT-NEXT: vmv.v.i v9, 0
1398-
; DOT-NEXT: li a0, 1
1399-
; DOT-NEXT: vqdot.vx v9, v8, a0
1403+
; DOT-NEXT: vqdot.vv v9, v8, v10
14001404
; DOT-NEXT: vmv.v.v v8, v9
14011405
; DOT-NEXT: ret
14021406
entry:
@@ -1405,7 +1409,6 @@ entry:
14051409
ret <4 x i32> %res
14061410
}
14071411

1408-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
14091412
define <4 x i32> @partial_of_zext(<16 x i8> %a) {
14101413
; NODOT-LABEL: partial_of_zext:
14111414
; NODOT: # %bb.0: # %entry
@@ -1426,10 +1429,11 @@ define <4 x i32> @partial_of_zext(<16 x i8> %a) {
14261429
;
14271430
; DOT-LABEL: partial_of_zext:
14281431
; DOT: # %bb.0: # %entry
1432+
; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
1433+
; DOT-NEXT: vmv.v.i v10, 1
14291434
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
14301435
; DOT-NEXT: vmv.v.i v9, 0
1431-
; DOT-NEXT: li a0, 1
1432-
; DOT-NEXT: vqdotu.vx v9, v8, a0
1436+
; DOT-NEXT: vqdotu.vv v9, v8, v10
14331437
; DOT-NEXT: vmv.v.v v8, v9
14341438
; DOT-NEXT: ret
14351439
entry:

llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,3 +957,56 @@ entry:
957957
%res = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 1 x i32> zeroinitializer, <vscale x 4 x i32> %mul)
958958
ret <vscale x 1 x i32> %res
959959
}
960+
961+
962+
define <vscale x 4 x i32> @partial_of_sext(<vscale x 16 x i8> %a) {
963+
; NODOT-LABEL: partial_of_sext:
964+
; NODOT: # %bb.0: # %entry
965+
; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
966+
; NODOT-NEXT: vsext.vf4 v16, v8
967+
; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
968+
; NODOT-NEXT: vadd.vv v8, v22, v16
969+
; NODOT-NEXT: vadd.vv v10, v18, v20
970+
; NODOT-NEXT: vadd.vv v8, v10, v8
971+
; NODOT-NEXT: ret
972+
;
973+
; DOT-LABEL: partial_of_sext:
974+
; DOT: # %bb.0: # %entry
975+
; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
976+
; DOT-NEXT: vmv.v.i v12, 1
977+
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
978+
; DOT-NEXT: vmv.v.i v10, 0
979+
; DOT-NEXT: vqdot.vv v10, v8, v12
980+
; DOT-NEXT: vmv.v.v v8, v10
981+
; DOT-NEXT: ret
982+
entry:
983+
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
984+
%res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
985+
ret <vscale x 4 x i32> %res
986+
}
987+
988+
define <vscale x 4 x i32> @partial_of_zext(<vscale x 16 x i8> %a) {
989+
; NODOT-LABEL: partial_of_zext:
990+
; NODOT: # %bb.0: # %entry
991+
; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
992+
; NODOT-NEXT: vzext.vf4 v16, v8
993+
; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
994+
; NODOT-NEXT: vadd.vv v8, v22, v16
995+
; NODOT-NEXT: vadd.vv v10, v18, v20
996+
; NODOT-NEXT: vadd.vv v8, v10, v8
997+
; NODOT-NEXT: ret
998+
;
999+
; DOT-LABEL: partial_of_zext:
1000+
; DOT: # %bb.0: # %entry
1001+
; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
1002+
; DOT-NEXT: vmv.v.i v12, 1
1003+
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
1004+
; DOT-NEXT: vmv.v.i v10, 0
1005+
; DOT-NEXT: vqdotu.vv v10, v8, v12
1006+
; DOT-NEXT: vmv.v.v v8, v10
1007+
; DOT-NEXT: ret
1008+
entry:
1009+
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
1010+
%res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
1011+
ret <vscale x 4 x i32> %res
1012+
}

0 commit comments

Comments
 (0)