Skip to content

Commit 9af8047

Browse files
committed
Check the definition of mask operand (i.e. V0)
1 parent 18c4a08 commit 9af8047

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 47 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1675,6 +1675,10 @@ bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
16751675
if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
16761676
return false;
16771677

1678+
assert(MI1.getMF() == MI2.getMF());
1679+
const MachineRegisterInfo *MRI = &MI1.getMF()->getRegInfo();
1680+
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
1681+
16781682
// Make sure vtype operands are also the same.
16791683
const MCInstrDesc &Desc = get(MI1.getOpcode());
16801684
const uint64_t TSFlags = Desc.TSFlags;
@@ -1697,10 +1701,49 @@ bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
16971701
return false;
16981702

16991703
// Mask
1700-
// There might be more sophisticated ways to check equality of masks, but
1701-
// right now we simply check if they're the same virtual register.
1702-
if (RISCVII::usesMaskPolicy(TSFlags) && !checkRegOperand(4))
1703-
return false;
1704+
if (RISCVII::usesMaskPolicy(TSFlags)) {
1705+
const MachineBasicBlock *MBB = MI1.getParent();
1706+
const MachineBasicBlock::const_reverse_iterator It1(&MI1);
1707+
const MachineBasicBlock::const_reverse_iterator It2(&MI2);
1708+
Register MI1VReg;
1709+
1710+
bool SeenMI2 = false;
1711+
for (auto End = MBB->rend(), It = It1; It != End; ++It) {
1712+
if (It == It2) {
1713+
SeenMI2 = true;
1714+
if (!MI1VReg.isValid())
1715+
// There is no V0 def between MI1 and MI2; they're sharing the
1716+
// same V0.
1717+
break;
1718+
}
1719+
1720+
if (It->definesRegister(RISCV::V0, TRI)) {
1721+
Register SrcReg =
1722+
TRI->lookThruCopyLike(It->getOperand(1).getReg(), MRI);
1723+
1724+
if (!MI1VReg.isValid()) {
1725+
// This is the V0 def for MI1.
1726+
MI1VReg = SrcReg;
1727+
continue;
1728+
}
1729+
1730+
// Some random mask updates.
1731+
if (!SeenMI2)
1732+
continue;
1733+
1734+
// This is the V0 def for MI2; check if it's the same as that of
1735+
// MI1.
1736+
if (MI1VReg != SrcReg)
1737+
return false;
1738+
else
1739+
break;
1740+
}
1741+
}
1742+
1743+
// If we haven't encountered MI2, it's likely that this function was
1744+
// called incorrectly (e.g. MI1 is before MI2).
1745+
assert(SeenMI2 && "MI2 is expected to appear before MI1");
1746+
}
17041747

17051748
// Tail / Mask policies
17061749
if (RISCVII::hasVecPolicyOp(TSFlags) &&

llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -215,15 +215,16 @@ entry:
215215
ret <vscale x 1 x i8> %c
216216
}
217217

218-
define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
218+
define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m, <vscale x 1 x i1> %m2) nounwind {
219219
; CHECK-LABEL: vadd_vv_mask_negative:
220220
; CHECK: # %bb.0: # %entry
221221
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu
222-
; CHECK-NEXT: vmv1r.v v10, v8
223-
; CHECK-NEXT: vadd.vv v10, v8, v9, v0.t
222+
; CHECK-NEXT: vmv1r.v v11, v8
223+
; CHECK-NEXT: vadd.vv v11, v8, v9, v0.t
224224
; CHECK-NEXT: vmv1r.v v9, v8
225-
; CHECK-NEXT: vadd.vv v9, v8, v10, v0.t
226-
; CHECK-NEXT: vadd.vv v8, v8, v9
225+
; CHECK-NEXT: vadd.vv v9, v8, v11, v0.t
226+
; CHECK-NEXT: vmv1r.v v0, v10
227+
; CHECK-NEXT: vadd.vv v8, v8, v9, v0.t
227228
; CHECK-NEXT: ret
228229
entry:
229230
%a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
@@ -240,8 +241,6 @@ entry:
240241
<vscale x 1 x i1> %m,
241242
i32 %2, i32 1)
242243

243-
%splat = insertelement <vscale x 1 x i1> poison, i1 1, i32 0
244-
%m2 = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
245244
%c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
246245
<vscale x 1 x i8> %0,
247246
<vscale x 1 x i8> %0,

0 commit comments

Comments
 (0)