Skip to content

Commit d8b9af7

Browse files
committed
remove use of vectors of length 1 in firstbithigh64 code
1 parent 85a1d4e commit d8b9af7

File tree

4 files changed

+84
-71
lines changed

4 files changed

+84
-71
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -449,6 +449,15 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
449449
return Res;
450450
}
451451

452+
// Dispatch on the kind of SpvType: anything that is not an OpTypeVector is
// handed to getOrCreateConstInt, vector types go to getOrCreateConstVector.
// Val, I, SpvType, TII and ZeroAsNull are passed through unchanged.
Register SPIRVGlobalRegistry::getOrCreateConstScalarOrVector(
    uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
    const SPIRVInstrInfo &TII, bool ZeroAsNull) {
  const bool IsVectorType = SpvType->getOpcode() == SPIRV::OpTypeVector;

  // Non-vector request: build the constant through the integer path.
  if (!IsVectorType)
    return getOrCreateConstInt(Val, I, SpvType, TII, ZeroAsNull);

  // Vector request: build the constant through the vector path.
  return getOrCreateConstVector(Val, I, SpvType, TII, ZeroAsNull);
}
460+
452461
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
453462
MachineInstr &I,
454463
SPIRVType *SpvType,

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -492,6 +492,10 @@ class SPIRVGlobalRegistry {
492492
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
493493
SPIRVType *SpvType = nullptr);
494494

495+
// Forwards to getOrCreateConstVector when SpvType's opcode is OpTypeVector
// and to getOrCreateConstInt otherwise; Val, I, SpvType, TII and ZeroAsNull
// are passed through unchanged to whichever helper is chosen.
Register getOrCreateConstScalarOrVector(uint64_t Val, MachineInstr &I,
496+
SPIRVType *SpvType,
497+
const SPIRVInstrInfo &TII,
498+
bool ZeroAsNull = true);
495499
Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
496500
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
497501
bool ZeroAsNull = true);

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 65 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -2759,82 +2759,82 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
27592759
Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
27602760
Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
27612761

2762-
// 3. check if result of each top 32 bits is == -1
2763-
// split result vector into vector of high bits and vector of low bits
2764-
// get high bits
2765-
// if ResType is a scalar we need a vector anyways because our code
2766-
// operates on vectors, even vectors of length one.
2767-
SPIRVType *VResType = ResType;
2762+
// 3. split result vector into high bits and low bits
2763+
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2764+
Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2765+
2766+
bool ZeroAsNull = STI.isOpenCLEnv();
27682767
bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
2769-
if (isScalarRes)
2770-
VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
2771-
// count should be one.
2768+
if (isScalarRes) {
2769+
// if scalar do a vector extract
2770+
Result &= selectNAryOpWithSrcs(
2771+
HighReg, ResType, I,
2772+
{FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
2773+
SPIRV::OpVectorExtractDynamic);
2774+
Result &= selectNAryOpWithSrcs(
2775+
LowReg, ResType, I,
2776+
{FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
2777+
SPIRV::OpVectorExtractDynamic);
2778+
} else { // vector case do a shufflevector
2779+
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2780+
TII.get(SPIRV::OpVectorShuffle))
2781+
.addDef(HighReg)
2782+
.addUse(GR.getSPIRVTypeID(ResType))
2783+
.addUse(FBHReg)
2784+
.addUse(FBHReg);
2785+
// ^^ this vector will not be selected from; could be empty
2786+
unsigned j;
2787+
for (j = 0; j < count * 2; j += 2) {
2788+
MIB.addImm(j);
2789+
}
2790+
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2791+
2792+
// get low bits
2793+
MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2794+
TII.get(SPIRV::OpVectorShuffle))
2795+
.addDef(LowReg)
2796+
.addUse(GR.getSPIRVTypeID(ResType))
2797+
.addUse(FBHReg)
2798+
.addUse(FBHReg);
2799+
// ^^ this vector will not be selected from; could be empty
2800+
for (j = 1; j < count * 2; j += 2) {
2801+
MIB.addImm(j);
2802+
}
2803+
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2804+
}
2805+
2806+
// 4. check if result of each top 32 bits is == -1
2807+
SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
2808+
if (!isScalarRes)
2809+
BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
27722810

2773-
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2774-
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2775-
TII.get(SPIRV::OpVectorShuffle))
2776-
.addDef(HighReg)
2777-
.addUse(GR.getSPIRVTypeID(VResType))
2778-
.addUse(FBHReg)
2779-
.addUse(FBHReg);
2780-
// ^^ this vector will not be selected from; could be empty
2781-
unsigned j;
2782-
for (j = 0; j < count * 2; j += 2) {
2783-
MIB.addImm(j);
2784-
}
2785-
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2786-
2787-
// get low bits
2788-
Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2789-
MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2790-
TII.get(SPIRV::OpVectorShuffle))
2791-
.addDef(LowReg)
2792-
.addUse(GR.getSPIRVTypeID(VResType))
2793-
.addUse(FBHReg)
2794-
.addUse(FBHReg);
2795-
// ^^ this vector will not be selected from; could be empty
2796-
for (j = 1; j < count * 2; j += 2) {
2797-
MIB.addImm(j);
2798-
}
2799-
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2800-
2801-
SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
2802-
GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
28032811
// check if the high bits are == -1;
2804-
Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
2812+
Register NegOneReg =
2813+
GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
28052814
// true if -1
28062815
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
28072816
Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
28082817
SPIRV::OpIEqual);
28092818

28102819
// Select low bits if true in BReg, otherwise high bits
2811-
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2812-
Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
2813-
SPIRV::OpSelectVIVCond);
2820+
unsigned selectOp =
2821+
isScalarRes ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
2822+
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2823+
Result &= selectNAryOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg},
2824+
selectOp);
28142825

28152826
// Add 32 for high bits, 0 for low bits
2816-
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2817-
bool ZeroAsNull = STI.isOpenCLEnv();
2818-
Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
2819-
Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
2820-
Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
2821-
SPIRV::OpSelectVIVCond);
2822-
2823-
Register AddReg = ResVReg;
2824-
if (isScalarRes)
2825-
AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2826-
2827-
Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
2828-
SPIRV::OpIAddV);
2829-
2830-
// convert result back to scalar if necessary
2831-
if (!isScalarRes)
2832-
return Result;
2833-
else
2834-
return Result & selectNAryOpWithSrcs(
2835-
ResVReg, ResType, I,
2836-
{AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
2837-
SPIRV::OpVectorExtractDynamic);
2827+
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2828+
Register Reg0 =
2829+
GR.getOrCreateConstScalarOrVector(0, I, ResType, TII, ZeroAsNull);
2830+
Register Reg32 =
2831+
GR.getOrCreateConstScalarOrVector(32, I, ResType, TII, ZeroAsNull);
2832+
Result &=
2833+
selectNAryOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);
2834+
2835+
return Result &=
2836+
selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg},
2837+
isScalarRes ? SPIRV::OpIAddS : SPIRV::OpIAddV);
28382838
}
28392839

28402840
bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,

llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
33

44
; CHECK: OpMemoryModel Logical GLSL450
5+
; CHECK: [[Z:%.*]] = OpConstant %[[#]] 0
6+
; CHECK: [[X:%.*]] = OpConstant %[[#]] 1
57

68
define noundef i32 @firstbituhigh_i32(i32 noundef %a) {
79
entry:
@@ -37,13 +39,12 @@ define noundef i32 @firstbituhigh_i64(i64 noundef %a) {
3739
entry:
3840
; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
3941
; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
40-
; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
41-
; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
42+
; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
43+
; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
4244
; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
4345
; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
4446
; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
4547
; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
46-
; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
4748
%elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i64(i64 %a)
4849
ret i32 %elt.firstbituhigh
4950
}
@@ -82,13 +83,12 @@ define noundef i32 @firstbitshigh_i64(i64 noundef %a) {
8283
entry:
8384
; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
8485
; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindSMsb [[O]]
85-
; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
86-
; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
86+
; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
87+
; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
8788
; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
8889
; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
8990
; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
9091
; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
91-
; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
9292
%elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i64(i64 %a)
9393
ret i32 %elt.firstbitshigh
9494
}

0 commit comments

Comments
 (0)