Commit 933182e

[RISCV] Improve support for forming widening multiplies when one input is a scalar splat.
If one input of a fixed vector multiply is a sign/zero extend and the other operand is a splat of a scalar, we can use a widening multiply if the scalar value has sufficient sign/zero bits.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D110028
1 parent 1f73f0c commit 933182e
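As a rough illustration of the condition described above, the combine can only narrow the splatted scalar when half of its bits are redundant, i.e. when it has enough sign bits (for vwmul) or known zero bits (for vwmulu). A minimal standalone sketch, not LLVM code and with hypothetical helper names, for a 16-bit element type narrowed to 8 bits:

#include <cstdint>

// fitsNarrowSigned   -> the scalar could feed vwmul  (sign-extended form)
// fitsNarrowUnsigned -> the scalar could feed vwmulu (zero-extended form)
bool fitsNarrowSigned(int16_t V) {
  return V >= INT8_MIN && V <= INT8_MAX; // at least 9 of the 16 bits are sign bits
}
bool fitsNarrowUnsigned(uint16_t V) {
  return (V >> 8) == 0;                  // the upper half is known to be zero
}

In the DAG combine below, the same test is performed with DAG.ComputeNumSignBits and DAG.MaskedValueIsZero on the scalar operand of the VMV_V_X_VL splat.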

File tree

3 files changed: +567, -39 lines

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 86 additions & 37 deletions
@@ -6576,6 +6576,87 @@ static SDValue performANY_EXTENDCombine(SDNode *N,
   return SDValue(N, 0);
 }
 
+// Try to form VWMUL or VWMULU.
+// FIXME: Support VWMULSU.
+static SDValue combineMUL_VLToVWMUL(SDNode *N, SDValue Op0, SDValue Op1,
+                                    SelectionDAG &DAG) {
+  assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode");
+  bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
+  bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
+  if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse())
+    return SDValue();
+
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+
+  // Make sure the mask and VL match.
+  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL)
+    return SDValue();
+
+  MVT VT = N->getSimpleValueType(0);
+
+  // Determine the narrow size for a widening multiply.
+  unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+  MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
+                                  VT.getVectorElementCount());
+
+  SDLoc DL(N);
+
+  // See if the other operand is the same opcode.
+  if (Op0.getOpcode() == Op1.getOpcode()) {
+    if (!Op1.hasOneUse())
+      return SDValue();
+
+    // Make sure the mask and VL match.
+    if (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
+      return SDValue();
+
+    Op1 = Op1.getOperand(0);
+  } else if (Op1.getOpcode() == RISCVISD::VMV_V_X_VL) {
+    // The operand is a splat of a scalar.
+
+    // The VL must be the same.
+    if (Op1.getOperand(1) != VL)
+      return SDValue();
+
+    // Get the scalar value.
+    Op1 = Op1.getOperand(0);
+
+    // See if we have enough sign bits or zero bits in the scalar to use a
+    // widening multiply by splatting to a smaller element size.
+    unsigned EltBits = VT.getScalarSizeInBits();
+    unsigned ScalarBits = Op1.getValueSizeInBits();
+    // Make sure we're getting all element bits from the scalar register.
+    // FIXME: Support implicit sign extension of vmv.v.x?
+    if (ScalarBits < EltBits)
+      return SDValue();
+
+    if (IsSignExt) {
+      if (DAG.ComputeNumSignBits(Op1) <= (ScalarBits - NarrowSize))
+        return SDValue();
+    } else {
+      APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
+      if (!DAG.MaskedValueIsZero(Op1, Mask))
+        return SDValue();
+    }
+
+    Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, Op1, VL);
+  } else
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+
+  // Re-introduce narrower extends if needed.
+  unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+  if (Op0.getValueType() != NarrowVT)
+    Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
+  if (Op1.getValueType() != NarrowVT)
+    Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
+
+  unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+  return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -7027,45 +7108,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::MUL_VL: {
-    // Try to form VWMUL or VWMULU.
-    // FIXME: Look for splat of extended scalar as well.
-    // FIXME: Support VWMULSU.
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
-    bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
-    bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
-    if ((!IsSignExt && !IsZeroExt) || Op0.getOpcode() != Op1.getOpcode())
-      return SDValue();
-
-    // Make sure the extends have a single use.
-    if (!Op0.hasOneUse() || !Op1.hasOneUse())
-      return SDValue();
-
-    SDValue Mask = N->getOperand(2);
-    SDValue VL = N->getOperand(3);
-    if (Op0.getOperand(1) != Mask || Op1.getOperand(1) != Mask ||
-        Op0.getOperand(2) != VL || Op1.getOperand(2) != VL)
-      return SDValue();
-
-    Op0 = Op0.getOperand(0);
-    Op1 = Op1.getOperand(0);
-
-    MVT VT = N->getSimpleValueType(0);
-    MVT NarrowVT =
-        MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits() / 2),
-                         VT.getVectorElementCount());
-
-    SDLoc DL(N);
-
-    // Re-introduce narrower extends if needed.
-    unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-    if (Op0.getValueType() != NarrowVT)
-      Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
-    if (Op1.getValueType() != NarrowVT)
-      Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
-
-    unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
-    return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
+    if (SDValue V = combineMUL_VLToVWMUL(N, Op0, Op1, DAG))
+      return V;
+    if (SDValue V = combineMUL_VLToVWMUL(N, Op1, Op0, DAG))
+      return V;
+    return SDValue();
   }
   }
 
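As a side note, the reason the narrowed multiply is safe at all: once both operands are representable in the narrow element type, the doubled-width result of vwmul/vwmulu holds the exact product. A small standalone check (plain C++, independent of the patch) of the signed 8-to-16-bit case:

#include <cassert>
#include <cstdint>

int main() {
  // Exhaustively confirm that the product of two 8-bit signed values
  // always fits in a 16-bit result, so no information is lost.
  for (int a = -128; a <= 127; ++a)
    for (int b = -128; b <= 127; ++b) {
      int16_t Wide = static_cast<int16_t>(a * b); // |a*b| <= 16384 < 2^15
      assert(Wide == a * b);
    }
  return 0;
}

The unsigned case used by vwmulu is analogous: 255 * 255 = 65025 still fits in 16 bits.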

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll

Lines changed: 238 additions & 2 deletions
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+experimental-v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
-; RUN: llc -mtriple=riscv64 -mattr=+experimental-v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv32 -mattr=+experimental-v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+experimental-v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
 
 define <2 x i16> @vwmul_v2i16(<2 x i8>* %x, <2 x i8>* %y) {
 ; CHECK-LABEL: vwmul_v2i16:
@@ -649,3 +649,239 @@ define <16 x i64> @vwmul_vx_v16i64(<16 x i32>* %x, i32 %y) {
   ret <16 x i64> %f
 }
 
+define <8 x i16> @vwmul_vx_v8i16_i8(<8 x i8>* %x, i8* %y) {
+; CHECK-LABEL: vwmul_vx_v8i16_i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vle8.v v25, (a0)
+; CHECK-NEXT:    lb a0, 0(a1)
+; CHECK-NEXT:    vwmul.vx v8, v25, a0
+; CHECK-NEXT:    ret
+  %a = load <8 x i8>, <8 x i8>* %x
+  %b = load i8, i8* %y
+  %c = sext i8 %b to i16
+  %d = insertelement <8 x i16> undef, i16 %c, i32 0
+  %e = shufflevector <8 x i16> %d, <8 x i16> undef, <8 x i32> zeroinitializer
+  %f = sext <8 x i8> %a to <8 x i16>
+  %g = mul <8 x i16> %e, %f
+  ret <8 x i16> %g
+}
+
+define <8 x i16> @vwmul_vx_v8i16_i16(<8 x i8>* %x, i16* %y) {
+; CHECK-LABEL: vwmul_vx_v8i16_i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vle8.v v25, (a0)
+; CHECK-NEXT:    lh a0, 0(a1)
+; CHECK-NEXT:    vsetvli zero, zero, e16, m1, ta, mu
+; CHECK-NEXT:    vsext.vf2 v26, v25
+; CHECK-NEXT:    vmul.vx v8, v26, a0
+; CHECK-NEXT:    ret
+  %a = load <8 x i8>, <8 x i8>* %x
+  %b = load i16, i16* %y
+  %d = insertelement <8 x i16> undef, i16 %b, i32 0
+  %e = shufflevector <8 x i16> %d, <8 x i16> undef, <8 x i32> zeroinitializer
+  %f = sext <8 x i8> %a to <8 x i16>
+  %g = mul <8 x i16> %e, %f
+  ret <8 x i16> %g
+}
+
+define <4 x i32> @vwmul_vx_v4i32_i8(<4 x i16>* %x, i8* %y) {
+; CHECK-LABEL: vwmul_vx_v4i32_i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    lb a0, 0(a1)
+; CHECK-NEXT:    vwmul.vx v8, v25, a0
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, <4 x i16>* %x
+  %b = load i8, i8* %y
+  %c = sext i8 %b to i32
+  %d = insertelement <4 x i32> undef, i32 %c, i32 0
+  %e = shufflevector <4 x i32> %d, <4 x i32> undef, <4 x i32> zeroinitializer
+  %f = sext <4 x i16> %a to <4 x i32>
+  %g = mul <4 x i32> %e, %f
+  ret <4 x i32> %g
+}
+
+define <4 x i32> @vwmul_vx_v4i32_i16(<4 x i16>* %x, i16* %y) {
+; CHECK-LABEL: vwmul_vx_v4i32_i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    lh a0, 0(a1)
+; CHECK-NEXT:    vwmul.vx v8, v25, a0
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, <4 x i16>* %x
+  %b = load i16, i16* %y
+  %c = sext i16 %b to i32
+  %d = insertelement <4 x i32> undef, i32 %c, i32 0
+  %e = shufflevector <4 x i32> %d, <4 x i32> undef, <4 x i32> zeroinitializer
+  %f = sext <4 x i16> %a to <4 x i32>
+  %g = mul <4 x i32> %e, %f
+  ret <4 x i32> %g
+}
+
+define <4 x i32> @vwmul_vx_v4i32_i32(<4 x i16>* %x, i32* %y) {
+; CHECK-LABEL: vwmul_vx_v4i32_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    lw a0, 0(a1)
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vsext.vf2 v26, v25
+; CHECK-NEXT:    vmul.vx v8, v26, a0
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, <4 x i16>* %x
+  %b = load i32, i32* %y
+  %d = insertelement <4 x i32> undef, i32 %b, i32 0
+  %e = shufflevector <4 x i32> %d, <4 x i32> undef, <4 x i32> zeroinitializer
+  %f = sext <4 x i16> %a to <4 x i32>
+  %g = mul <4 x i32> %e, %f
+  ret <4 x i32> %g
+}
+
+define <2 x i64> @vwmul_vx_v2i64_i8(<2 x i32>* %x, i8* %y) {
+; RV32-LABEL: vwmul_vx_v2i64_i8:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV32-NEXT:    lb a1, 0(a1)
+; RV32-NEXT:    vle32.v v25, (a0)
+; RV32-NEXT:    srai a0, a1, 31
+; RV32-NEXT:    sw a1, 8(sp)
+; RV32-NEXT:    sw a0, 12(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vlse64.v v26, (a0), zero
+; RV32-NEXT:    vsetvli zero, zero, e64, m1, ta, mu
+; RV32-NEXT:    vsext.vf2 v27, v25
+; RV32-NEXT:    vmul.vv v8, v26, v27
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwmul_vx_v2i64_i8:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV64-NEXT:    vle32.v v25, (a0)
+; RV64-NEXT:    lb a0, 0(a1)
+; RV64-NEXT:    vwmul.vx v8, v25, a0
+; RV64-NEXT:    ret
+  %a = load <2 x i32>, <2 x i32>* %x
+  %b = load i8, i8* %y
+  %c = sext i8 %b to i64
+  %d = insertelement <2 x i64> undef, i64 %c, i64 0
+  %e = shufflevector <2 x i64> %d, <2 x i64> undef, <2 x i32> zeroinitializer
+  %f = sext <2 x i32> %a to <2 x i64>
+  %g = mul <2 x i64> %e, %f
+  ret <2 x i64> %g
+}
+
+define <2 x i64> @vwmul_vx_v2i64_i16(<2 x i32>* %x, i16* %y) {
+; RV32-LABEL: vwmul_vx_v2i64_i16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV32-NEXT:    lh a1, 0(a1)
+; RV32-NEXT:    vle32.v v25, (a0)
+; RV32-NEXT:    srai a0, a1, 31
+; RV32-NEXT:    sw a1, 8(sp)
+; RV32-NEXT:    sw a0, 12(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vlse64.v v26, (a0), zero
+; RV32-NEXT:    vsetvli zero, zero, e64, m1, ta, mu
+; RV32-NEXT:    vsext.vf2 v27, v25
+; RV32-NEXT:    vmul.vv v8, v26, v27
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwmul_vx_v2i64_i16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV64-NEXT:    vle32.v v25, (a0)
+; RV64-NEXT:    lh a0, 0(a1)
+; RV64-NEXT:    vwmul.vx v8, v25, a0
+; RV64-NEXT:    ret
+  %a = load <2 x i32>, <2 x i32>* %x
+  %b = load i16, i16* %y
+  %c = sext i16 %b to i64
+  %d = insertelement <2 x i64> undef, i64 %c, i64 0
+  %e = shufflevector <2 x i64> %d, <2 x i64> undef, <2 x i32> zeroinitializer
+  %f = sext <2 x i32> %a to <2 x i64>
+  %g = mul <2 x i64> %e, %f
+  ret <2 x i64> %g
+}
+
+define <2 x i64> @vwmul_vx_v2i64_i32(<2 x i32>* %x, i32* %y) {
+; RV32-LABEL: vwmul_vx_v2i64_i32:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV32-NEXT:    lw a1, 0(a1)
+; RV32-NEXT:    vle32.v v25, (a0)
+; RV32-NEXT:    srai a0, a1, 31
+; RV32-NEXT:    sw a1, 8(sp)
+; RV32-NEXT:    sw a0, 12(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vlse64.v v26, (a0), zero
+; RV32-NEXT:    vsetvli zero, zero, e64, m1, ta, mu
+; RV32-NEXT:    vsext.vf2 v27, v25
+; RV32-NEXT:    vmul.vv v8, v26, v27
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwmul_vx_v2i64_i32:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV64-NEXT:    vle32.v v25, (a0)
+; RV64-NEXT:    lw a0, 0(a1)
+; RV64-NEXT:    vwmul.vx v8, v25, a0
+; RV64-NEXT:    ret
+  %a = load <2 x i32>, <2 x i32>* %x
+  %b = load i32, i32* %y
+  %c = sext i32 %b to i64
+  %d = insertelement <2 x i64> undef, i64 %c, i64 0
+  %e = shufflevector <2 x i64> %d, <2 x i64> undef, <2 x i32> zeroinitializer
+  %f = sext <2 x i32> %a to <2 x i64>
+  %g = mul <2 x i64> %e, %f
+  ret <2 x i64> %g
+}
+
+define <2 x i64> @vwmul_vx_v2i64_i64(<2 x i32>* %x, i64* %y) {
+; RV32-LABEL: vwmul_vx_v2i64_i64:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV32-NEXT:    lw a2, 4(a1)
+; RV32-NEXT:    lw a1, 0(a1)
+; RV32-NEXT:    vle32.v v25, (a0)
+; RV32-NEXT:    sw a2, 12(sp)
+; RV32-NEXT:    sw a1, 8(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vlse64.v v26, (a0), zero
+; RV32-NEXT:    vsetvli zero, zero, e64, m1, ta, mu
+; RV32-NEXT:    vsext.vf2 v27, v25
+; RV32-NEXT:    vmul.vv v8, v26, v27
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwmul_vx_v2i64_i64:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
+; RV64-NEXT:    vle32.v v25, (a0)
+; RV64-NEXT:    ld a0, 0(a1)
+; RV64-NEXT:    vsetvli zero, zero, e64, m1, ta, mu
+; RV64-NEXT:    vsext.vf2 v26, v25
+; RV64-NEXT:    vmul.vx v8, v26, a0
+; RV64-NEXT:    ret
+  %a = load <2 x i32>, <2 x i32>* %x
+  %b = load i64, i64* %y
+  %d = insertelement <2 x i64> undef, i64 %b, i64 0
+  %e = shufflevector <2 x i64> %d, <2 x i64> undef, <2 x i32> zeroinitializer
+  %f = sext <2 x i32> %a to <2 x i64>
+  %g = mul <2 x i64> %e, %f
+  ret <2 x i64> %g
+}
