Skip to content

[RISCV] Don't fold vmerge.vvm or vmv.v.v into vredsum.vs if AVL changed #99006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3753,11 +3753,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (!Info)
return false;

// When Mask is not a true mask, this transformation is illegal for some
// operations whose results are affected by mask, like viota.m.
if (Info->MaskAffectsResult && Mask && !usesAllOnesMask(Mask, Glue))
return false;

// If True has a merge operand then it needs to be the same as vmerge's False,
// since False will be used for the result's merge operand.
if (HasTiedDest && !isImplicitDef(True->getOperand(0))) {
Expand Down Expand Up @@ -3835,6 +3830,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (!VL)
return false;

// Some operations produce different elementwise results depending on the
// active elements, like viota.m or vredsum. This transformation is illegal
// for these if we change the active elements (i.e. mask or VL).
if (Info->ActiveElementsAffectResult) {
if (Mask && !usesAllOnesMask(Mask, Glue))
return false;
if (TrueVL != VL)
return false;
}

// If we end up changing the VL or mask of True, then we need to make sure it
// doesn't raise any observable fp exceptions, since changing the active
// elements will affect how fflags is set.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ struct RISCVMaskedPseudoInfo {
uint16_t MaskedPseudo;
uint16_t UnmaskedPseudo;
uint8_t MaskOpIdx;
uint8_t MaskAffectsResult : 1;
uint8_t ActiveElementsAffectResult : 1;
};
#define GET_RISCVMaskedPseudosTable_DECL
#include "RISCVGenSearchableTables.inc"
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -561,17 +561,17 @@ def RISCVVIntrinsicsTable : GenericTable {
// unmasked variant. For all but compares, both the masked and
// unmasked variant have a passthru and policy operand. For compares,
// neither has a policy op, and only the masked version has a passthru.
class RISCVMaskedPseudo<bits<4> MaskIdx, bit MaskAffectsRes=false> {
class RISCVMaskedPseudo<bits<4> MaskIdx, bit ActiveAffectsRes=false> {
Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
bits<4> MaskOpIdx = MaskIdx;
bit MaskAffectsResult = MaskAffectsRes;
bit ActiveElementsAffectResult = ActiveAffectsRes;
}

def RISCVMaskedPseudosTable : GenericTable {
let FilterClass = "RISCVMaskedPseudo";
let CppTypeName = "RISCVMaskedPseudoInfo";
let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx", "MaskAffectsResult"];
let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx", "ActiveElementsAffectResult"];
let PrimaryKey = ["MaskedPseudo"];
let PrimaryKeyName = "getMaskedPseudoInfo";
}
Expand Down Expand Up @@ -2065,7 +2065,7 @@ multiclass VPseudoVIOTA_M {
SchedUnary<"WriteVIotaV", "ReadVIotaV", mx,
forceMergeOpRead=true>;
def "_" # mx # "_MASK" : VPseudoUnaryMask<m.vrclass, VR, constraint>,
RISCVMaskedPseudo<MaskIdx=2, MaskAffectsRes=true>,
RISCVMaskedPseudo<MaskIdx=2, ActiveAffectsRes=true>,
SchedUnary<"WriteVIotaV", "ReadVIotaV", mx,
forceMergeOpRead=true>;
}
Expand Down Expand Up @@ -3162,7 +3162,7 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
defvar mx = MInfo.MX;
def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class>;
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class>,
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
RISCVMaskedPseudo<MaskIdx=3, ActiveAffectsRes=true>;
}
}

Expand All @@ -3179,7 +3179,7 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
def "_" # mx # "_E" # sew # "_MASK"
: VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
Op2Class>,
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
RISCVMaskedPseudo<MaskIdx=3, ActiveAffectsRes=true>;
}
}

Expand Down
54 changes: 52 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/combine-vmv.ll
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ define <vscale x 4 x i32> @vadd(<vscale x 4 x i32> %passthru, <vscale x 4 x i32>
ret <vscale x 4 x i32> %w
}

; A masked vadd whose passthru is poison feeds a vmv.v.v that uses the same
; %vl and supplies %passthru. The expected output below is a single masked
; vadd.vv (v0.t) writing v8 under a tail-undisturbed, mask-undisturbed vsetvli,
; with no separate vmv.v.v, i.e. the move is folded into the masked add.
; NOTE(review): the trailing `iXLen 3` on the vadd.mask call is presumably the
; policy operand (tail agnostic, mask agnostic) -- confirm against the
; intrinsic definition.
define <vscale x 4 x i32> @vadd_mask(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i1> %m, iXLen %vl) {
; CHECK-LABEL: vadd_mask:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m2, tu, mu
; CHECK-NEXT: vadd.vv v8, v10, v12, v0.t
; CHECK-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.mask.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i1> %m, iXLen %vl, iXLen 3)
%w = call <vscale x 4 x i32> @llvm.riscv.vmv.v.v.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, iXLen %vl)
ret <vscale x 4 x i32> %w
}

define <vscale x 4 x i32> @vadd_undef(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: vadd_undef:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -106,8 +117,8 @@ declare <vscale x 4 x float> @llvm.riscv.vmv.v.v.nxv4f32(<vscale x 4 x float>, <

declare <vscale x 4 x float> @llvm.riscv.vfadd.nxv4f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, iXLen, iXLen)

define <vscale x 4 x float> @vfadd(<vscale x 4 x float> %passthru, <vscale x 4 x float> %a, <vscale x 4 x float> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: vfadd:
define <vscale x 4 x float> @unfoldable_vfadd(<vscale x 4 x float> %passthru, <vscale x 4 x float> %a, <vscale x 4 x float> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: unfoldable_vfadd:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m2, ta, ma
; CHECK-NEXT: vfadd.vv v10, v10, v12
Expand All @@ -118,3 +129,42 @@ define <vscale x 4 x float> @vfadd(<vscale x 4 x float> %passthru, <vscale x 4 x
%w = call <vscale x 4 x float> @llvm.riscv.vmv.v.v.nxv4f32(<vscale x 4 x float> %passthru, <vscale x 4 x float> %v, iXLen %vl2)
ret <vscale x 4 x float> %w
}

; The vfadd and the vmv.v.v that consumes it use the same %vl, so folding the
; move into the add does not change which elements the vfadd operates on.
; The expected output below is a single tail-undisturbed vfadd.vv writing v8,
; with no separate vmv.v.v.
; NOTE(review): the `iXLen 7` operand of the vfadd call is presumably the
; rounding-mode (frm) operand selecting the dynamic rounding mode -- confirm
; against the intrinsic definition.
define <vscale x 4 x float> @foldable_vfadd(<vscale x 4 x float> %passthru, <vscale x 4 x float> %a, <vscale x 4 x float> %b, iXLen %vl) {
; CHECK-LABEL: foldable_vfadd:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m2, tu, ma
; CHECK-NEXT: vfadd.vv v8, v10, v12
; CHECK-NEXT: ret
%v = call <vscale x 4 x float> @llvm.riscv.vfadd.nxv4f32.nxv4f32(<vscale x 4 x float> poison, <vscale x 4 x float> %a, <vscale x 4 x float> %b, iXLen 7, iXLen %vl)
%w = call <vscale x 4 x float> @llvm.riscv.vmv.v.v.nxv4f32(<vscale x 4 x float> %passthru, <vscale x 4 x float> %v, iXLen %vl)
ret <vscale x 4 x float> %w
}

; This shouldn't be folded because we need to preserve exceptions with
; "fpexcept.strict" exception behaviour, and changing the VL may hide them.
; In the expected output below the vfadd.vv runs under a vsetvli that requests
; VLMAX (rs1 = zero), while the vmv.v.v runs under a second vsetvli that uses
; %vl (a0); the two instructions stay separate.
define <vscale x 4 x float> @unfoldable_constrained_fadd(<vscale x 4 x float> %passthru, <vscale x 4 x float> %x, <vscale x 4 x float> %y, iXLen %vl) strictfp {
; CHECK-LABEL: unfoldable_constrained_fadd:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e32, m2, ta, ma
; CHECK-NEXT: vfadd.vv v10, v10, v12
; CHECK-NEXT: vsetvli zero, a0, e32, m2, tu, ma
; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%a = call <vscale x 4 x float> @llvm.experimental.constrained.fadd(<vscale x 4 x float> %x, <vscale x 4 x float> %y, metadata !"round.dynamic", metadata !"fpexcept.strict") strictfp
%b = call <vscale x 4 x float> @llvm.riscv.vmv.v.v.nxv4f32(<vscale x 4 x float> %passthru, <vscale x 4 x float> %a, iXLen %vl) strictfp
ret <vscale x 4 x float> %b
}

; The vredsum is called with VL operand -1 while the vmv.v.v that consumes it
; uses VL 1. A reduction's result depends on which source elements are active,
; so folding the move into the reduction (and thereby shrinking the
; reduction's VL to 1) would change the computed sum. The expected output
; below therefore keeps the vredsum.vs under a VLMAX vsetvli and a separate
; vmv.v.v under `vsetivli zero, 1`.
define <vscale x 2 x i32> @unfoldable_vredsum(<vscale x 2 x i32> %passthru, <vscale x 4 x i32> %x, <vscale x 2 x i32> %y) {
; CHECK-LABEL: unfoldable_vredsum:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma
; CHECK-NEXT: vredsum.vs v9, v10, v9
; CHECK-NEXT: vsetivli zero, 1, e32, m1, tu, ma
; CHECK-NEXT: vmv.v.v v8, v9
; CHECK-NEXT: ret
%a = call <vscale x 2 x i32> @llvm.riscv.vredsum.nxv2i32.nxv4i32(<vscale x 2 x i32> poison, <vscale x 4 x i32> %x, <vscale x 2 x i32> %y, iXLen -1)
%b = call <vscale x 2 x i32> @llvm.riscv.vmv.v.v.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, iXLen 1)
ret <vscale x 2 x i32> %b
}
18 changes: 18 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,24 @@ define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passth
ret <vscale x 2 x float> %b
}

; The vredsum is called with VL operand -1, and its result is selected by a
; vmerge whose mask is all ones (`splat (i1 -1)`) but whose VL is 1. Even with
; an all-ones mask, folding the vmerge into the reduction would change the
; reduction's VL and hence the elements being summed. The expected output
; below keeps the vredsum.vs under a VLMAX vsetvli and emits a separate
; vmv.v.v under `vsetivli zero, 1`.
define <vscale x 2 x i32> @unfoldable_vredsum_allones_mask_diff_vl(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y) {
; CHECK-LABEL: unfoldable_vredsum_allones_mask_diff_vl:
; CHECK: # %bb.0:
; CHECK-NEXT: vmv1r.v v11, v8
; CHECK-NEXT: vsetvli a0, zero, e32, m1, tu, ma
; CHECK-NEXT: vredsum.vs v11, v9, v10
; CHECK-NEXT: vsetivli zero, 1, e32, m1, tu, ma
; CHECK-NEXT: vmv.v.v v8, v11
; CHECK-NEXT: ret
%a = call <vscale x 2 x i32> @llvm.riscv.vredsum.nxv2i32.nxv2i32(
<vscale x 2 x i32> %passthru,
<vscale x 2 x i32> %x,
<vscale x 2 x i32> %y,
i64 -1)
%b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 1)
ret <vscale x 2 x i32> %b
}

declare <vscale x 32 x i16> @llvm.riscv.vle.nxv32i16.i64(<vscale x 32 x i16>, ptr nocapture, i64)
declare <vscale x 32 x i8> @llvm.riscv.vssubu.mask.nxv32i8.i8.i64(<vscale x 32 x i8>, <vscale x 32 x i8>, i8, <vscale x 32 x i1>, i64, i64 immarg)
declare <vscale x 32 x i1> @llvm.riscv.vmseq.nxv32i8.nxv32i8.i64(<vscale x 32 x i8>, <vscale x 32 x i8>, i64)
Expand Down
Loading