Skip to content

Commit 124c93c

Browse files
committed
[RISCV] When matching SROIW, check all 64 bits of the OR mask
We need to make sure the upper 32 bits of the OR mask are all ones to ensure the result is properly sign extended. Previously we only checked the lower 32 bits of the mask. I've also added a check that the shift amount is less than 32. Without that check, the original code asserts inside maskLeadingOnes if the SROI check is removed or the SROIW pattern is checked first. I've refactored the code to use early outs to reduce nesting. I've also updated SLOIW matching with the same changes, but I couldn't find a broken test case with the existing code. Differential Revision: https://reviews.llvm.org/D90961
1 parent aeb0fdf commit 124c93c

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -314,62 +314,66 @@ bool RISCVDAGToDAGISel::SelectSLLIUW(SDValue N, SDValue &RS1, SDValue &Shamt) {
314314
// and then we check that VC1, the mask used to fill with ones, is compatible
315315
// with VC2, the shamt:
316316
//
317-
// VC1 == maskTrailingOnes<uint32_t>(VC2)
317+
// VC2 < 32
318+
// VC1 == maskTrailingOnes<uint64_t>(VC2)
318319

319320
bool RISCVDAGToDAGISel::SelectSLOIW(SDValue N, SDValue &RS1, SDValue &Shamt) {
320-
if (Subtarget->getXLenVT() == MVT::i64 &&
321-
N.getOpcode() == ISD::SIGN_EXTEND_INREG &&
322-
cast<VTSDNode>(N.getOperand(1))->getVT() == MVT::i32) {
323-
if (N.getOperand(0).getOpcode() == ISD::OR) {
324-
SDValue Or = N.getOperand(0);
325-
if (Or.getOperand(0).getOpcode() == ISD::SHL) {
326-
SDValue Shl = Or.getOperand(0);
327-
if (isa<ConstantSDNode>(Shl.getOperand(1)) &&
328-
isa<ConstantSDNode>(Or.getOperand(1))) {
329-
uint32_t VC1 = Or.getConstantOperandVal(1);
330-
uint32_t VC2 = Shl.getConstantOperandVal(1);
331-
if (VC1 == maskTrailingOnes<uint32_t>(VC2)) {
332-
RS1 = Shl.getOperand(0);
333-
Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
334-
Shl.getOperand(1).getValueType());
335-
return true;
336-
}
337-
}
338-
}
339-
}
340-
}
341-
return false;
321+
assert(Subtarget->is64Bit() && "SLOIW should only be matched on RV64");
322+
if (N.getOpcode() != ISD::SIGN_EXTEND_INREG ||
323+
cast<VTSDNode>(N.getOperand(1))->getVT() != MVT::i32)
324+
return false;
325+
326+
SDValue Or = N.getOperand(0);
327+
328+
if (Or.getOpcode() != ISD::OR || !isa<ConstantSDNode>(Or.getOperand(1)))
329+
return false;
330+
331+
SDValue Shl = Or.getOperand(0);
332+
if (Shl.getOpcode() != ISD::SHL || !isa<ConstantSDNode>(Shl.getOperand(1)))
333+
return false;
334+
335+
uint64_t VC1 = Or.getConstantOperandVal(1);
336+
uint64_t VC2 = Shl.getConstantOperandVal(1);
337+
338+
if (VC2 >= 32 || VC1 != maskTrailingOnes<uint64_t>(VC2))
339+
return false;
340+
341+
RS1 = Shl.getOperand(0);
342+
Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
343+
Shl.getOperand(1).getValueType());
344+
return true;
342345
}
343346

344347
// Check that it is a SROIW (Shift Right Ones Immediate i32 on RV64).
345348
// We first check that it is the right node tree:
346349
//
347-
// (OR (SHL RS1, VC2), VC1)
350+
// (OR (SRL RS1, VC2), VC1)
348351
//
349352
// and then we check that VC1, the mask used to fill with ones, is compatible
350353
// with VC2, the shamt:
351354
//
352-
// VC1 == maskLeadingOnes<uint32_t>(VC2)
353-
355+
// VC2 < 32
356+
// VC1 == maskTrailingZeros<uint64_t>(32 - VC2)
357+
//
354358
bool RISCVDAGToDAGISel::SelectSROIW(SDValue N, SDValue &RS1, SDValue &Shamt) {
355-
if (N.getOpcode() == ISD::OR && Subtarget->getXLenVT() == MVT::i64) {
356-
SDValue Or = N;
357-
if (Or.getOperand(0).getOpcode() == ISD::SRL) {
358-
SDValue Srl = Or.getOperand(0);
359-
if (isa<ConstantSDNode>(Srl.getOperand(1)) &&
360-
isa<ConstantSDNode>(Or.getOperand(1))) {
361-
uint32_t VC1 = Or.getConstantOperandVal(1);
362-
uint32_t VC2 = Srl.getConstantOperandVal(1);
363-
if (VC1 == maskLeadingOnes<uint32_t>(VC2)) {
364-
RS1 = Srl.getOperand(0);
365-
Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
366-
Srl.getOperand(1).getValueType());
367-
return true;
368-
}
369-
}
370-
}
371-
}
372-
return false;
359+
assert(Subtarget->is64Bit() && "SROIW should only be matched on RV64");
360+
if (N.getOpcode() != ISD::OR || !isa<ConstantSDNode>(N.getOperand(1)))
361+
return false;
362+
363+
SDValue Srl = N.getOperand(0);
364+
if (Srl.getOpcode() != ISD::SRL || !isa<ConstantSDNode>(Srl.getOperand(1)))
365+
return false;
366+
367+
uint64_t VC1 = N.getConstantOperandVal(1);
368+
uint64_t VC2 = Srl.getConstantOperandVal(1);
369+
370+
if (VC2 >= 32 || VC1 != maskTrailingZeros<uint64_t>(32 - VC2))
371+
return false;
372+
373+
RS1 = Srl.getOperand(0);
374+
Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
375+
Srl.getOperand(1).getValueType());
376+
return true;
373377
}
374378

375379
// Check that it is a RORIW (i32 Right Rotate Immediate on RV64).

llvm/test/CodeGen/RISCV/rv64Zbb.ll

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ define signext i32 @sroi_i32(i32 signext %a) nounwind {
166166
; This is similar to the type-legalized version of sroiw, but the mask is 0 in
167167
; the upper bits instead of 1 so the result is not sign extended. Make sure we
168168
; don't match it to sroiw.
169-
; FIXME: We're matching it to sroiw.
170169
define i64 @sroiw_bug(i64 %a) nounwind {
171170
; RV64I-LABEL: sroiw_bug:
172171
; RV64I: # %bb.0:
@@ -178,12 +177,18 @@ define i64 @sroiw_bug(i64 %a) nounwind {
178177
;
179178
; RV64IB-LABEL: sroiw_bug:
180179
; RV64IB: # %bb.0:
181-
; RV64IB-NEXT: sroiw a0, a0, 1
180+
; RV64IB-NEXT: srli a0, a0, 1
181+
; RV64IB-NEXT: addi a1, zero, 1
182+
; RV64IB-NEXT: slli a1, a1, 31
183+
; RV64IB-NEXT: or a0, a0, a1
182184
; RV64IB-NEXT: ret
183185
;
184186
; RV64IBB-LABEL: sroiw_bug:
185187
; RV64IBB: # %bb.0:
186-
; RV64IBB-NEXT: sroiw a0, a0, 1
188+
; RV64IBB-NEXT: srli a0, a0, 1
189+
; RV64IBB-NEXT: addi a1, zero, 1
190+
; RV64IBB-NEXT: slli a1, a1, 31
191+
; RV64IBB-NEXT: or a0, a0, a1
187192
; RV64IBB-NEXT: ret
188193
%neg = lshr i64 %a, 1
189194
%neg12 = or i64 %neg, 2147483648

0 commit comments

Comments
 (0)