Commit f8c63a7

[SDAG] Allow scalable vectors in ComputeNumSignBits
This is a continuation of the series of patches adding lane-wise support for scalable vectors in various knownbits-esque routines. The basic idea here is that for scalable vectors we track a single lane, which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane-wise reasoning on many arithmetic operations.

Differential Revision: https://reviews.llvm.org/D137141
1 parent 625f08d commit f8c63a7
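
The core of the change is the DemandedElts encoding: a fixed-length vector gets one demanded bit per lane, while a scalable vector is represented by a single bit that implicitly stands for every runtime lane. Below is a minimal sketch of that convention, not code from the patch; it assumes LLVM's ADT headers are available and the helper name is purely illustrative.

#include "llvm/ADT/APInt.h"
using namespace llvm;

// Hypothetical helper mirroring the convention this patch adopts: fixed-length
// vectors demand each of their NumElts lanes explicitly, while a scalable
// vector is tracked as one bit that is implicitly broadcast to all lanes,
// whose count is unknown at compile time.
static APInt makeDemandedElts(bool IsScalableVector, unsigned NumElts) {
  return IsScalableVector ? APInt(1, 1)                  // one implicit lane
                          : APInt::getAllOnes(NumElts);  // one bit per lane
}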

7 files changed: 56 additions & 72 deletions

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 33 additions & 7 deletions

@@ -3962,11 +3962,10 @@ bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val) const {
 unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const {
   EVT VT = Op.getValueType();

-  // TODO: Assume we don't know anything for now.
-  if (VT.isScalableVector())
-    return 1;
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes. This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return ComputeNumSignBits(Op, DemandedElts, Depth);
@@ -3989,7 +3988,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   if (Depth >= MaxRecursionDepth)
     return 1; // Limit search depth.

-  if (!DemandedElts || VT.isScalableVector())
+  if (!DemandedElts)
     return 1; // No demanded elts, better to assume we don't know anything.

   unsigned Opcode = Op.getOpcode();
@@ -4004,7 +4003,16 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::MERGE_VALUES:
     return ComputeNumSignBits(Op.getOperand(Op.getResNo()), DemandedElts,
                               Depth + 1);
+  case ISD::SPLAT_VECTOR: {
+    // Check if the sign bits of source go down as far as the truncated value.
+    unsigned NumSrcBits = Op.getOperand(0).getValueSizeInBits();
+    unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
+    if (NumSrcSignBits > (NumSrcBits - VTBits))
+      return NumSrcSignBits - (NumSrcBits - VTBits);
+    break;
+  }
   case ISD::BUILD_VECTOR:
+    assert(!VT.isScalableVector());
     Tmp = VTBits;
     for (unsigned i = 0, e = Op.getNumOperands(); (i < e) && (Tmp > 1); ++i) {
       if (!DemandedElts[i])
@@ -4049,6 +4057,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   }

   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue N0 = Op.getOperand(0);
     EVT SrcVT = N0.getValueType();
     unsigned SrcBits = SrcVT.getScalarSizeInBits();
@@ -4106,6 +4116,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     Tmp2 = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
     return std::max(Tmp, Tmp2);
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
@@ -4323,6 +4335,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::EXTRACT_ELEMENT: {
+    if (VT.isScalableVector())
+      return 1;
     const int KnownSign = ComputeNumSignBits(Op.getOperand(0), Depth+1);
     const int BitWidth = Op.getValueSizeInBits();
     const int Items = Op.getOperand(0).getValueSizeInBits() / BitWidth;
@@ -4336,6 +4350,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return std::clamp(KnownSign - rIndex * BitWidth, 0, BitWidth);
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
     // If we know the element index, split the demand between the
     // source vector and the inserted element, otherwise assume we need
     // the original demanded vector elements and the value.
@@ -4366,6 +4382,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return Tmp;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue InVec = Op.getOperand(0);
     SDValue EltNo = Op.getOperand(1);
     EVT VecVT = InVec.getValueType();
@@ -4404,6 +4422,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return 1;
     // Determine the minimum number of sign bits across all demanded
     // elts of the input vectors. Early out if the result is already 1.
     Tmp = std::numeric_limits<unsigned>::max();
@@ -4422,6 +4442,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return Tmp;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return 1;
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -4492,7 +4514,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
       // We only need to handle vectors - computeKnownBits should handle
       // scalar cases.
       Type *CstTy = Cst->getType();
-      if (CstTy->isVectorTy() &&
+      if (CstTy->isVectorTy() && !VT.isScalableVector() &&
          (NumElts * VTBits) == CstTy->getPrimitiveSizeInBits() &&
          VTBits == CstTy->getScalarSizeInBits()) {
        Tmp = VTBits;
@@ -4527,6 +4549,10 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
       Opcode == ISD::INTRINSIC_WO_CHAIN ||
       Opcode == ISD::INTRINSIC_W_CHAIN ||
       Opcode == ISD::INTRINSIC_VOID) {
+    // TODO: This can probably be removed once target code is audited. This
+    // is here purely to reduce patch size and review complexity.
+    if (VT.isScalableVector())
+      return 1;
     unsigned NumBits =
         TLI->ComputeNumSignBitsForTargetNode(Op, DemandedElts, *this, Depth);
     if (NumBits > 1)
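
The new ISD::SPLAT_VECTOR case helps explain several of the test changes below: whatever sign bits the scalar operand is known to have survive the implicit truncation into the (possibly narrower) element type, and the result holds for every lane of the splat. The following is a standalone sketch of that arithmetic, not the SelectionDAG code itself, and the function name is illustrative only.

#include <cassert>

// NumSrcBits:     width of the scalar being splatted (it may be wider than a lane).
// NumSrcSignBits: known sign bits of that scalar.
// VTBits:         element width of the resulting splat.
// Returns a conservative lower bound on the sign bits of every lane.
static unsigned splatSignBits(unsigned NumSrcBits, unsigned NumSrcSignBits,
                              unsigned VTBits) {
  if (NumSrcSignBits > NumSrcBits - VTBits)
    return NumSrcSignBits - (NumSrcBits - VTBits);
  return 1; // nothing useful is known
}

int main() {
  // For example, an i64 that is really a sign-extended i8 (57 sign bits)
  // splatted into 32-bit lanes still has at least 25 sign bits per lane.
  assert(splatSignBits(64, 57, 32) == 25);
  return 0;
}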

llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ define <vscale x 2 x float> @masked_gather_nxv2f32(float* %base, <vscale x 2 x i
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: ptrue p1.d
 ; CHECK-NEXT: sxth z0.d, p1/m, z0.d
-; CHECK-NEXT: ld1w { z0.d }, p0/z, [x0, z0.d, sxtw #2]
+; CHECK-NEXT: ld1w { z0.d }, p0/z, [x0, z0.d, lsl #2]
 ; CHECK-NEXT: ret
 %ptrs = getelementptr float, float* %base, <vscale x 2 x i16> %indices
 %data = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32(<vscale x 2 x float*> %ptrs, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)

llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll

Lines changed: 18 additions & 48 deletions

@@ -9,15 +9,10 @@ define <vscale x 2 x i8> @smulo_nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %
 ; CHECK-NEXT: ptrue p0.d
 ; CHECK-NEXT: sxtb z1.d, p0/m, z1.d
 ; CHECK-NEXT: sxtb z0.d, p0/m, z0.d
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT: asr z1.d, z0.d, #63
-; CHECK-NEXT: movprfx z3, z0
-; CHECK-NEXT: sxtb z3.d, p0/m, z0.d
-; CHECK-NEXT: cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT: cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT: movprfx z1, z0
+; CHECK-NEXT: sxtb z1.d, p0/m, z0.d
+; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT: ret
 %a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
@@ -35,15 +30,10 @@ define <vscale x 4 x i8> @smulo_nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %
 ; CHECK-NEXT: ptrue p0.s
 ; CHECK-NEXT: sxtb z1.s, p0/m, z1.s
 ; CHECK-NEXT: sxtb z0.s, p0/m, z0.s
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT: asr z1.s, z0.s, #31
-; CHECK-NEXT: movprfx z3, z0
-; CHECK-NEXT: sxtb z3.s, p0/m, z0.s
-; CHECK-NEXT: cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT: cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT: movprfx z1, z0
+; CHECK-NEXT: sxtb z1.s, p0/m, z0.s
+; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT: ret
 %a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
@@ -61,15 +51,10 @@ define <vscale x 8 x i8> @smulo_nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %
 ; CHECK-NEXT: ptrue p0.h
 ; CHECK-NEXT: sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT: sxtb z0.h, p0/m, z0.h
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: smulh z2.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT: mul z0.h, p0/m, z0.h, z1.h
-; CHECK-NEXT: asr z1.h, z0.h, #15
-; CHECK-NEXT: movprfx z3, z0
-; CHECK-NEXT: sxtb z3.h, p0/m, z0.h
-; CHECK-NEXT: cmpne p1.h, p0/z, z2.h, z1.h
-; CHECK-NEXT: cmpne p0.h, p0/z, z3.h, z0.h
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT: movprfx z1, z0
+; CHECK-NEXT: sxtb z1.h, p0/m, z0.h
+; CHECK-NEXT: cmpne p0.h, p0/z, z1.h, z0.h
 ; CHECK-NEXT: mov z0.h, p0/m, #0 // =0x0
 ; CHECK-NEXT: ret
 %a = call { <vscale x 8 x i8>, <vscale x 8 x i1> } @llvm.smul.with.overflow.nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y)
@@ -175,15 +160,10 @@ define <vscale x 2 x i16> @smulo_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1
 ; CHECK-NEXT: ptrue p0.d
 ; CHECK-NEXT: sxth z1.d, p0/m, z1.d
 ; CHECK-NEXT: sxth z0.d, p0/m, z0.d
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT: asr z1.d, z0.d, #63
-; CHECK-NEXT: movprfx z3, z0
-; CHECK-NEXT: sxth z3.d, p0/m, z0.d
-; CHECK-NEXT: cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT: cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT: movprfx z1, z0
+; CHECK-NEXT: sxth z1.d, p0/m, z0.d
+; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT: ret
 %a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
@@ -201,15 +181,10 @@ define <vscale x 4 x i16> @smulo_nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i1
 ; CHECK-NEXT: ptrue p0.s
 ; CHECK-NEXT: sxth z1.s, p0/m, z1.s
 ; CHECK-NEXT: sxth z0.s, p0/m, z0.s
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT: asr z1.s, z0.s, #31
-; CHECK-NEXT: movprfx z3, z0
-; CHECK-NEXT: sxth z3.s, p0/m, z0.s
-; CHECK-NEXT: cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT: cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT: movprfx z1, z0
+; CHECK-NEXT: sxth z1.s, p0/m, z0.s
+; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT: ret
 %a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
@@ -315,15 +290,10 @@ define <vscale x 2 x i32> @smulo_nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i3
 ; CHECK-NEXT: ptrue p0.d
 ; CHECK-NEXT: sxtw z1.d, p0/m, z1.d
 ; CHECK-NEXT: sxtw z0.d, p0/m, z0.d
-; CHECK-NEXT: movprfx z2, z0
-; CHECK-NEXT: smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT: asr z1.d, z0.d, #63
-; CHECK-NEXT: movprfx z3, z0
-; CHECK-NEXT: sxtw z3.d, p0/m, z0.d
-; CHECK-NEXT: cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT: cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT: movprfx z1, z0
+; CHECK-NEXT: sxtw z1.d, p0/m, z0.d
+; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT: ret
 %a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
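
The dropped smulh/asr/cmpne/sel sequence was the "high half matches the sign of the low half" overflow check. With sign bits now computed through the scalable multiply's sign-extended operands, that check can be proven always true, so only the truncation check remains. A brute-force sanity check of the nxv2i8 case under that reading (standalone C++; __int128 is a GCC/Clang extension):

#include <cassert>
#include <cstdint>

int main() {
  // After the sxtb instructions each 64-bit lane holds a value in [-128, 127],
  // so the full product fits in the lane and its high 64 bits are pure sign.
  for (int x = -128; x <= 127; ++x)
    for (int y = -128; y <= 127; ++y) {
      __int128 p = (__int128)x * y;
      int64_t lo = (int64_t)p;          // what `mul z0.d` computes
      int64_t hi = (int64_t)(p >> 64);  // what the removed `smulh` computed
      assert(hi == (lo >> 63));         // the removed compare could never fire
    }
  return 0;
}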

llvm/test/CodeGen/RISCV/rvv/vdiv-vp.ll

Lines changed: 1 addition & 4 deletions

@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vdiv_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT: vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT: vadd.vv v8, v8, v8
 ; CHECK-NEXT: vsra.vi v8, v8, 1
-; CHECK-NEXT: vmv.v.x v9, a0
-; CHECK-NEXT: vadd.vv v9, v9, v9
-; CHECK-NEXT: vsra.vi v9, v9, 1
 ; CHECK-NEXT: vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT: vdiv.vv v8, v8, v9, v0.t
+; CHECK-NEXT: vdiv.vx v8, v8, a0, v0.t
 ; CHECK-NEXT: ret
 %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
 %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer

llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll

Lines changed: 1 addition & 4 deletions

@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vmax_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT: vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT: vadd.vv v8, v8, v8
 ; CHECK-NEXT: vsra.vi v8, v8, 1
-; CHECK-NEXT: vmv.v.x v9, a0
-; CHECK-NEXT: vadd.vv v9, v9, v9
-; CHECK-NEXT: vsra.vi v9, v9, 1
 ; CHECK-NEXT: vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT: vmax.vv v8, v8, v9, v0.t
+; CHECK-NEXT: vmax.vx v8, v8, a0, v0.t
 ; CHECK-NEXT: ret
 %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
 %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer

llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll

Lines changed: 1 addition & 4 deletions

@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vmin_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT: vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT: vadd.vv v8, v8, v8
 ; CHECK-NEXT: vsra.vi v8, v8, 1
-; CHECK-NEXT: vmv.v.x v9, a0
-; CHECK-NEXT: vadd.vv v9, v9, v9
-; CHECK-NEXT: vsra.vi v9, v9, 1
 ; CHECK-NEXT: vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT: vmin.vv v8, v8, v9, v0.t
+; CHECK-NEXT: vmin.vx v8, v8, a0, v0.t
 ; CHECK-NEXT: ret
 %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
 %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer

llvm/test/CodeGen/RISCV/rvv/vrem-vp.ll

Lines changed: 1 addition & 4 deletions

@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vrem_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT: vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT: vadd.vv v8, v8, v8
 ; CHECK-NEXT: vsra.vi v8, v8, 1
-; CHECK-NEXT: vmv.v.x v9, a0
-; CHECK-NEXT: vadd.vv v9, v9, v9
-; CHECK-NEXT: vsra.vi v9, v9, 1
 ; CHECK-NEXT: vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT: vrem.vv v8, v8, v9, v0.t
+; CHECK-NEXT: vrem.vx v8, v8, a0, v0.t
 ; CHECK-NEXT: ret
 %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
 %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
