Skip to content

Commit 8c087cf

Browse files
committed
remove use of vectors of length 1 in firstbithigh64 code
1 parent e6cd0d0 commit 8c087cf

File tree

4 files changed

+84
-71
lines changed

4 files changed

+84
-71
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,15 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
449449
return Res;
450450
}
451451

452+
Register SPIRVGlobalRegistry::getOrCreateConstScalarOrVector(
453+
uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
454+
const SPIRVInstrInfo &TII, bool ZeroAsNull) {
455+
if (SpvType->getOpcode() == SPIRV::OpTypeVector)
456+
return getOrCreateConstVector(Val, I, SpvType, TII, ZeroAsNull);
457+
else
458+
return getOrCreateConstInt(Val, I, SpvType, TII, ZeroAsNull);
459+
}
460+
452461
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
453462
MachineInstr &I,
454463
SPIRVType *SpvType,

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,10 @@ class SPIRVGlobalRegistry {
492492
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
493493
SPIRVType *SpvType = nullptr);
494494

495+
Register getOrCreateConstScalarOrVector(uint64_t Val, MachineInstr &I,
496+
SPIRVType *SpvType,
497+
const SPIRVInstrInfo &TII,
498+
bool ZeroAsNull = true);
495499
Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
496500
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
497501
bool ZeroAsNull = true);

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2877,82 +2877,82 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
28772877
Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
28782878
Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
28792879

2880-
// 3. check if result of each top 32 bits is == -1
2881-
// split result vector into vector of high bits and vector of low bits
2882-
// get high bits
2883-
// if ResType is a scalar we need a vector anyways because our code
2884-
// operates on vectors, even vectors of length one.
2885-
SPIRVType *VResType = ResType;
2880+
// 3. split result vector into high bits and low bits
2881+
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2882+
Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2883+
2884+
bool ZeroAsNull = STI.isOpenCLEnv();
28862885
bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
2887-
if (isScalarRes)
2888-
VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
2889-
// count should be one.
2886+
if (isScalarRes) {
2887+
// if scalar do a vector extract
2888+
Result &= selectNAryOpWithSrcs(
2889+
HighReg, ResType, I,
2890+
{FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
2891+
SPIRV::OpVectorExtractDynamic);
2892+
Result &= selectNAryOpWithSrcs(
2893+
LowReg, ResType, I,
2894+
{FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
2895+
SPIRV::OpVectorExtractDynamic);
2896+
} else { // vector case do a shufflevector
2897+
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2898+
TII.get(SPIRV::OpVectorShuffle))
2899+
.addDef(HighReg)
2900+
.addUse(GR.getSPIRVTypeID(ResType))
2901+
.addUse(FBHReg)
2902+
.addUse(FBHReg);
2903+
// ^^ this vector will not be selected from; could be empty
2904+
unsigned j;
2905+
for (j = 0; j < count * 2; j += 2) {
2906+
MIB.addImm(j);
2907+
}
2908+
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2909+
2910+
// get low bits
2911+
MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2912+
TII.get(SPIRV::OpVectorShuffle))
2913+
.addDef(LowReg)
2914+
.addUse(GR.getSPIRVTypeID(ResType))
2915+
.addUse(FBHReg)
2916+
.addUse(FBHReg);
2917+
// ^^ this vector will not be selected from; could be empty
2918+
for (j = 1; j < count * 2; j += 2) {
2919+
MIB.addImm(j);
2920+
}
2921+
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2922+
}
2923+
2924+
// 4. check if result of each top 32 bits is == -1
2925+
SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
2926+
if (!isScalarRes)
2927+
BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
28902928

2891-
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2892-
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2893-
TII.get(SPIRV::OpVectorShuffle))
2894-
.addDef(HighReg)
2895-
.addUse(GR.getSPIRVTypeID(VResType))
2896-
.addUse(FBHReg)
2897-
.addUse(FBHReg);
2898-
// ^^ this vector will not be selected from; could be empty
2899-
unsigned j;
2900-
for (j = 0; j < count * 2; j += 2) {
2901-
MIB.addImm(j);
2902-
}
2903-
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2904-
2905-
// get low bits
2906-
Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2907-
MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
2908-
TII.get(SPIRV::OpVectorShuffle))
2909-
.addDef(LowReg)
2910-
.addUse(GR.getSPIRVTypeID(VResType))
2911-
.addUse(FBHReg)
2912-
.addUse(FBHReg);
2913-
// ^^ this vector will not be selected from; could be empty
2914-
for (j = 1; j < count * 2; j += 2) {
2915-
MIB.addImm(j);
2916-
}
2917-
Result &= MIB.constrainAllUses(TII, TRI, RBI);
2918-
2919-
SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
2920-
GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
29212929
// check if the high bits are == -1;
2922-
Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
2930+
Register NegOneReg =
2931+
GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
29232932
// true if -1
29242933
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
29252934
Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
29262935
SPIRV::OpIEqual);
29272936

29282937
// Select low bits if true in BReg, otherwise high bits
2929-
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2930-
Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
2931-
SPIRV::OpSelectVIVCond);
2938+
unsigned selectOp =
2939+
isScalarRes ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
2940+
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2941+
Result &= selectNAryOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg},
2942+
selectOp);
29322943

29332944
// Add 32 for high bits, 0 for low bits
2934-
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2935-
bool ZeroAsNull = STI.isOpenCLEnv();
2936-
Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
2937-
Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
2938-
Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
2939-
SPIRV::OpSelectVIVCond);
2940-
2941-
Register AddReg = ResVReg;
2942-
if (isScalarRes)
2943-
AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
2944-
2945-
Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
2946-
SPIRV::OpIAddV);
2947-
2948-
// convert result back to scalar if necessary
2949-
if (!isScalarRes)
2950-
return Result;
2951-
else
2952-
return Result & selectNAryOpWithSrcs(
2953-
ResVReg, ResType, I,
2954-
{AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
2955-
SPIRV::OpVectorExtractDynamic);
2945+
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
2946+
Register Reg0 =
2947+
GR.getOrCreateConstScalarOrVector(0, I, ResType, TII, ZeroAsNull);
2948+
Register Reg32 =
2949+
GR.getOrCreateConstScalarOrVector(32, I, ResType, TII, ZeroAsNull);
2950+
Result &=
2951+
selectNAryOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);
2952+
2953+
return Result &=
2954+
selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg},
2955+
isScalarRes ? SPIRV::OpIAddS : SPIRV::OpIAddV);
29562956
}
29572957

29582958
bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,

llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
33

44
; CHECK: OpMemoryModel Logical GLSL450
5+
; CHECK: [[Z:%.*]] = OpConstant %[[#]] 0
6+
; CHECK: [[X:%.*]] = OpConstant %[[#]] 1
57

68
define noundef i32 @firstbituhigh_i32(i32 noundef %a) {
79
entry:
@@ -37,13 +39,12 @@ define noundef i32 @firstbituhigh_i64(i64 noundef %a) {
3739
entry:
3840
; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
3941
; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
40-
; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
41-
; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
42+
; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
43+
; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
4244
; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
4345
; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
4446
; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
4547
; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
46-
; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
4748
%elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i64(i64 %a)
4849
ret i32 %elt.firstbituhigh
4950
}
@@ -82,13 +83,12 @@ define noundef i32 @firstbitshigh_i64(i64 noundef %a) {
8283
entry:
8384
; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
8485
; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindSMsb [[O]]
85-
; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
86-
; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
86+
; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
87+
; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
8788
; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
8889
; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
8990
; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
9091
; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
91-
; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
9292
%elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i64(i64 %a)
9393
ret i32 %elt.firstbitshigh
9494
}

0 commit comments

Comments
 (0)