Skip to content

Commit 8cc932e

Browse files
committed
Set KnownBits to correct width. Reduce 64-bit shl for all vector elts
Signed-off-by: John Lu <[email protected]>
1 parent d8d5f02 commit 8cc932e

File tree

4 files changed

+66
-31
lines changed

4 files changed

+66
-31
lines changed

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,14 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
430430
mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
431431
ConstantRange Range(Lower->getValue(), Upper->getValue());
432432
unsigned RangeBitWidth = Lower->getBitWidth();
433-
// BitWidth > RangeBitWidth can happen if Known is set to the width of a
434-
// vector load but Ranges describes a vector element.
435-
assert(BitWidth >= RangeBitWidth);
436433

437434
// The first CommonPrefixBits of all values in Range are equal.
438435
unsigned CommonPrefixBits =
439436
(Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
440-
APInt Mask = APInt::getBitsSet(BitWidth, RangeBitWidth - CommonPrefixBits,
441-
RangeBitWidth);
437+
// BitWidth must equal RangeBitWidth. Otherwise Mask will be set
438+
// incorrectly.
439+
assert(BitWidth == RangeBitWidth && "BitWidth must equal RangeBitWidth");
440+
APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
442441
APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(BitWidth);
443442
Known.One &= UnsignedMax & Mask;
444443
Known.Zero &= ~UnsignedMax & Mask;

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4011,15 +4011,19 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
40114011
// Fill in any known bits from range information. There are 3 types being
40124012
// used. The results VT (same vector elt size as BitWidth), the loaded
40134013
// MemoryVT (which may or may not be vector) and the range VTs original
4014-
// type. The range matadata needs the full range (i.e
4014+
// type. The range metadata needs the full range (i.e.
40154015
// MemoryVT().getSizeInBits()), which is truncated to the correct elt size
40164016
// if it is known. These are then extended to the original VT sizes below.
40174017
if (const MDNode *MD = LD->getRanges()) {
4018+
ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
4019+
4020+
// FIXME: If loads are modified (e.g. type legalization)
4021+
// so that the load type no longer matches the range metadata type, the
4022+
// range metadata should be updated to match the new load width.
4023+
Known0 = Known0.trunc(Lower->getBitWidth());
40184024
computeKnownBitsFromRangeMetadata(*MD, Known0);
40194025
if (VT.isVector()) {
4020-
// Handle truncation to the first demanded element.
4021-
// TODO: Figure out which demanded elements are covered
4022-
if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
4026+
if (!getDataLayout().isLittleEndian())
40234027
break;
40244028
Known0 = Known0.trunc(BitWidth);
40254029
}

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4084,29 +4084,32 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40844084
}
40854085
}
40864086

4087-
if (VT != MVT::i64)
4087+
if (VT.getScalarType() != MVT::i64)
40884088
return SDValue();
40894089

40904090
// i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
40914091

40924092
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
40934093
// common case, splitting this into a move and a 32-bit shift is faster and
40944094
// the same code size.
4095-
EVT TargetType = VT.getHalfSizedIntegerVT(*DAG.getContext());
4096-
EVT TargetVecPairType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
40974095
KnownBits Known = DAG.computeKnownBits(RHS);
40984096

4099-
if (Known.getMinValue().getZExtValue() < TargetType.getSizeInBits())
4097+
EVT ElementType = VT.getScalarType();
4098+
EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
4099+
EVT TargetType = (VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
4100+
: TargetScalarType);
4101+
4102+
if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
41004103
return SDValue();
41014104
SDValue ShiftAmt;
41024105

41034106
if (CRHS) {
4104-
ShiftAmt =
4105-
DAG.getConstant(RHSVal - TargetType.getSizeInBits(), SL, TargetType);
4107+
ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
4108+
TargetType);
41064109
} else {
41074110
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
41084111
const SDValue ShiftMask =
4109-
DAG.getConstant(TargetType.getSizeInBits() - 1, SL, TargetType);
4112+
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
41104113
// This AND instruction will clamp out of bounds shift values.
41114114
// It will also be removed during later instruction selection.
41124115
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
@@ -4116,9 +4119,24 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
41164119
SDValue NewShift =
41174120
DAG.getNode(ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags());
41184121

4119-
const SDValue Zero = DAG.getConstant(0, SL, TargetType);
4120-
4121-
SDValue Vec = DAG.getBuildVector(TargetVecPairType, SL, {Zero, NewShift});
4122+
const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
4123+
SDValue Vec;
4124+
4125+
if (VT.isVector()) {
4126+
EVT ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
4127+
SmallVector<SDValue, 8> Ops;
4128+
for (unsigned I = 0, E = TargetType.getVectorNumElements(); I != E; ++I) {
4129+
SDValue Index = DAG.getConstant(I, SL, MVT::i32);
4130+
Ops.push_back(Zero);
4131+
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, TargetScalarType,
4132+
NewShift, Index);
4133+
Ops.push_back(Elt);
4134+
}
4135+
Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, Ops);
4136+
} else {
4137+
EVT ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
4138+
Vec = DAG.getBuildVector(ConcatType, SL, {Zero, NewShift});
4139+
}
41224140
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
41234141
}
41244142

@@ -5182,7 +5200,13 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
51825200
break;
51835201
}
51845202
case ISD::SHL: {
5185-
if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
5203+
// Range metadata can be invalidated when loads are converted to legal types
5204+
// (e.g. v2i64 -> v4i32).
5205+
// Try to convert vector shl before type legalization so that range metadata
5206+
// can be utilized.
5207+
if (!(N->getValueType(0).isVector() &&
5208+
DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
5209+
DCI.getDAGCombineLevel() < AfterLegalizeDAG)
51865210
break;
51875211

51885212
return performShlCombine(N, DCI);

llvm/test/CodeGen/AMDGPU/shl64_reduce.ll

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ define <2 x i64> @shl_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
3434
; CHECK-LABEL: shl_v2_metadata:
3535
; CHECK: ; %bb.0:
3636
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
37-
; CHECK-NEXT: flat_load_dwordx4 v[4:7], v[4:5]
37+
; CHECK-NEXT: flat_load_dwordx4 v[3:6], v[4:5]
3838
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
39-
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v6, v[2:3]
40-
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v4, v0
39+
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v3, v0
40+
; CHECK-NEXT: v_lshlrev_b32_e32 v3, v5, v2
4141
; CHECK-NEXT: v_mov_b32_e32 v0, 0
42+
; CHECK-NEXT: v_mov_b32_e32 v2, 0
4243
; CHECK-NEXT: s_setpc_b64 s[30:31]
4344
%shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0
4445
%shl = shl <2 x i64> %arg0, %shift.amt
@@ -49,12 +50,15 @@ define <3 x i64> @shl_v3_metadata(<3 x i64> %arg0, ptr %arg1.ptr) {
4950
; CHECK-LABEL: shl_v3_metadata:
5051
; CHECK: ; %bb.0:
5152
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
52-
; CHECK-NEXT: flat_load_dword v12, v[6:7] offset:16
53+
; CHECK-NEXT: flat_load_dword v1, v[6:7] offset:16
5354
; CHECK-NEXT: flat_load_dwordx4 v[8:11], v[6:7]
5455
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
55-
; CHECK-NEXT: v_lshlrev_b64 v[4:5], v12, v[4:5]
56-
; CHECK-NEXT: v_lshlrev_b64 v[0:1], v8, v[0:1]
57-
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v10, v[2:3]
56+
; CHECK-NEXT: v_lshlrev_b32_e32 v5, v1, v4
57+
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v8, v0
58+
; CHECK-NEXT: v_lshlrev_b32_e32 v3, v10, v2
59+
; CHECK-NEXT: v_mov_b32_e32 v0, 0
60+
; CHECK-NEXT: v_mov_b32_e32 v2, 0
61+
; CHECK-NEXT: v_mov_b32_e32 v4, 0
5862
; CHECK-NEXT: s_setpc_b64 s[30:31]
5963
%shift.amt = load <3 x i64>, ptr %arg1.ptr, !range !0
6064
%shl = shl <3 x i64> %arg0, %shift.amt
@@ -69,11 +73,15 @@ define <4 x i64> @shl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
6973
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
7074
; CHECK-NEXT: flat_load_dwordx4 v[13:16], v[8:9] offset:16
7175
; CHECK-NEXT: ; kill: killed $vgpr8 killed $vgpr9
72-
; CHECK-NEXT: v_lshlrev_b64 v[0:1], v10, v[0:1]
73-
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v12, v[2:3]
76+
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v10, v0
77+
; CHECK-NEXT: v_lshlrev_b32_e32 v3, v12, v2
7478
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
75-
; CHECK-NEXT: v_lshlrev_b64 v[4:5], v13, v[4:5]
76-
; CHECK-NEXT: v_lshlrev_b64 v[6:7], v15, v[6:7]
79+
; CHECK-NEXT: v_lshlrev_b32_e32 v5, v13, v4
80+
; CHECK-NEXT: v_lshlrev_b32_e32 v7, v15, v6
81+
; CHECK-NEXT: v_mov_b32_e32 v0, 0
82+
; CHECK-NEXT: v_mov_b32_e32 v2, 0
83+
; CHECK-NEXT: v_mov_b32_e32 v4, 0
84+
; CHECK-NEXT: v_mov_b32_e32 v6, 0
7785
; CHECK-NEXT: s_setpc_b64 s[30:31]
7886
%shift.amt = load <4 x i64>, ptr %arg1.ptr, !range !0
7987
%shl = shl <4 x i64> %arg0, %shift.amt

0 commit comments

Comments
 (0)