Skip to content

[RISCV] Make fixed-point instructions commutable #90372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3132,6 +3132,11 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
case CASE_RVV_OPCODE_WIDEN(VWMACC_VV):
case CASE_RVV_OPCODE_WIDEN(VWMACCU_VV):
case CASE_RVV_OPCODE_UNMASK(VADC_VVM):
case CASE_RVV_OPCODE(VSADD_VV):
case CASE_RVV_OPCODE(VSADDU_VV):
case CASE_RVV_OPCODE(VAADD_VV):
case CASE_RVV_OPCODE(VAADDU_VV):
case CASE_RVV_OPCODE(VSMUL_VV):
// Operands 2 and 3 are commutable.
return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3);
case CASE_VFMA_SPLATS(FMADD):
Expand Down
29 changes: 16 additions & 13 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -2146,8 +2146,9 @@ multiclass VPseudoBinaryRoundingMode<VReg RetClass,
string Constraint = "",
int sew = 0,
int UsesVXRM = 1,
int TargetConstraintType = 1> {
let VLMul = MInfo.value, SEW=sew in {
int TargetConstraintType = 1,
bit Commutable = 0> {
let VLMul = MInfo.value, SEW=sew, isCommutable = Commutable in {
defvar suffix = !if(sew, "_" # MInfo.MX # "_E" # sew, "_" # MInfo.MX);
def suffix : VPseudoBinaryNoMaskRoundingMode<RetClass, Op1Class, Op2Class,
Constraint, UsesVXRM,
Expand Down Expand Up @@ -2232,8 +2233,9 @@ multiclass VPseudoBinaryV_VV<LMULInfo m, string Constraint = "", int sew = 0, bi
defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew, Commutable=Commutable>;
}

multiclass VPseudoBinaryV_VV_RM<LMULInfo m, string Constraint = ""> {
defm _VV : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, m.vrclass, m, Constraint>;
multiclass VPseudoBinaryV_VV_RM<LMULInfo m, string Constraint = "", bit Commutable = 0> {
defm _VV : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, m.vrclass, m, Constraint,
Commutable=Commutable>;
}

// Similar to VPseudoBinaryV_VV, but uses MxListF.
Expand Down Expand Up @@ -2715,10 +2717,11 @@ multiclass VPseudoVGTR_VV_VX_VI<Operand ImmType = simm5, string Constraint = "">
}
}

multiclass VPseudoVSALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = ""> {
multiclass VPseudoVSALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = "",
bit Commutable = 0> {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV<m, Constraint>,
defm "" : VPseudoBinaryV_VV<m, Constraint, Commutable=Commutable>,
SchedBinary<"WriteVSALUV", "ReadVSALUV", "ReadVSALUX", mx,
forceMergeOpRead=true>;
defm "" : VPseudoBinaryV_VX<m, Constraint>,
Expand Down Expand Up @@ -2788,7 +2791,7 @@ multiclass VPseudoVSALU_VV_VX {
multiclass VPseudoVSMUL_VV_VX_RM {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV_RM<m>,
defm "" : VPseudoBinaryV_VV_RM<m, Commutable=1>,
SchedBinary<"WriteVSMulV", "ReadVSMulV", "ReadVSMulV", mx,
forceMergeOpRead=true>;
defm "" : VPseudoBinaryV_VX_RM<m>,
Expand All @@ -2797,10 +2800,10 @@ multiclass VPseudoVSMUL_VV_VX_RM {
}
}

multiclass VPseudoVAALU_VV_VX_RM {
multiclass VPseudoVAALU_VV_VX_RM<bit Commutable = 0> {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV_RM<m>,
defm "" : VPseudoBinaryV_VV_RM<m, Commutable=Commutable>,
SchedBinary<"WriteVAALUV", "ReadVAALUV", "ReadVAALUV", mx,
forceMergeOpRead=true>;
defm "" : VPseudoBinaryV_VX_RM<m>,
Expand Down Expand Up @@ -6448,17 +6451,17 @@ defm PseudoVMV_V : VPseudoUnaryVMV_V_X_I;
// 12.1. Vector Single-Width Saturating Add and Subtract
//===----------------------------------------------------------------------===//
let Defs = [VXSAT], hasSideEffects = 1 in {
defm PseudoVSADDU : VPseudoVSALU_VV_VX_VI;
defm PseudoVSADD : VPseudoVSALU_VV_VX_VI;
defm PseudoVSADDU : VPseudoVSALU_VV_VX_VI<Commutable=1>;
defm PseudoVSADD : VPseudoVSALU_VV_VX_VI<Commutable=1>;
defm PseudoVSSUBU : VPseudoVSALU_VV_VX;
defm PseudoVSSUB : VPseudoVSALU_VV_VX;
}

//===----------------------------------------------------------------------===//
// 12.2. Vector Single-Width Averaging Add and Subtract
//===----------------------------------------------------------------------===//
defm PseudoVAADDU : VPseudoVAALU_VV_VX_RM;
defm PseudoVAADD : VPseudoVAALU_VV_VX_RM;
defm PseudoVAADDU : VPseudoVAALU_VV_VX_RM<Commutable=1>;
defm PseudoVAADD : VPseudoVAALU_VV_VX_RM<Commutable=1>;
defm PseudoVASUBU : VPseudoVAALU_VV_VX_RM;
defm PseudoVASUB : VPseudoVAALU_VV_VX_RM;

Expand Down
14 changes: 6 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/commutable.ll
Original file line number Diff line number Diff line change
Expand Up @@ -724,10 +724,9 @@ define <vscale x 1 x i64> @commutable_vaadd_vv(<vscale x 1 x i64> %0, <vscale x
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaadd.vv v10, v8, v9
; CHECK-NEXT: vaadd.vv v8, v9, v8
; CHECK-NEXT: vaadd.vv v8, v8, v9
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: vadd.vv v8, v8, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vaadd.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen 0, iXLen %2)
Expand All @@ -743,7 +742,7 @@ define <vscale x 1 x i64> @commutable_vaadd_vv_masked(<vscale x 1 x i64> %0, <vs
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaadd.vv v10, v8, v9, v0.t
; CHECK-NEXT: vaadd.vv v8, v9, v8, v0.t
; CHECK-NEXT: vaadd.vv v8, v8, v9, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
Expand All @@ -760,10 +759,9 @@ define <vscale x 1 x i64> @commutable_vaaddu_vv(<vscale x 1 x i64> %0, <vscale x
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaaddu.vv v10, v8, v9
; CHECK-NEXT: vaaddu.vv v8, v9, v8
; CHECK-NEXT: vaaddu.vv v8, v8, v9
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: vadd.vv v8, v8, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vaaddu.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen 0, iXLen %2)
Expand All @@ -779,7 +777,7 @@ define <vscale x 1 x i64> @commutable_vaaddu_vv_masked(<vscale x 1 x i64> %0, <v
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaaddu.vv v10, v8, v9, v0.t
; CHECK-NEXT: vaaddu.vv v8, v9, v8, v0.t
; CHECK-NEXT: vaaddu.vv v8, v8, v9, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.