Skip to content

Commit 12cac48

Browse files
committed
implement countbits correctly
1 parent 5c92f23 commit 12cac48

File tree

6 files changed

+202
-88
lines changed

6 files changed

+202
-88
lines changed

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 75 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -705,66 +705,90 @@ float4 cosh(float4);
705705

706706
#ifdef __HLSL_ENABLE_16_BIT
707707
_HLSL_AVAILABILITY(shadermodel, 6.2)
708-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
709-
int16_t countbits(int16_t);
708+
constexpr uint countbits(int16_t x) {
709+
return __builtin_elementwise_popcount(x);
710+
}
710711
_HLSL_AVAILABILITY(shadermodel, 6.2)
711-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
712-
int16_t2 countbits(int16_t2);
712+
constexpr uint2 countbits(int16_t2 x) {
713+
return __builtin_elementwise_popcount(x);
714+
}
713715
_HLSL_AVAILABILITY(shadermodel, 6.2)
714-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
715-
int16_t3 countbits(int16_t3);
716+
constexpr uint3 countbits(int16_t3 x) {
717+
return __builtin_elementwise_popcount(x);
718+
}
716719
_HLSL_AVAILABILITY(shadermodel, 6.2)
717-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
718-
int16_t4 countbits(int16_t4);
720+
constexpr uint4 countbits(int16_t4 x) {
721+
return __builtin_elementwise_popcount(x);
722+
}
719723
_HLSL_AVAILABILITY(shadermodel, 6.2)
720-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
721-
uint16_t countbits(uint16_t);
724+
constexpr uint countbits(uint16_t x) {
725+
return __builtin_elementwise_popcount(x);
726+
}
722727
_HLSL_AVAILABILITY(shadermodel, 6.2)
723-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
724-
uint16_t2 countbits(uint16_t2);
728+
constexpr uint2 countbits(uint16_t2 x) {
729+
return __builtin_elementwise_popcount(x);
730+
}
725731
_HLSL_AVAILABILITY(shadermodel, 6.2)
726-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
727-
uint16_t3 countbits(uint16_t3);
732+
constexpr uint3 countbits(uint16_t3 x) {
733+
return __builtin_elementwise_popcount(x);
734+
}
728735
_HLSL_AVAILABILITY(shadermodel, 6.2)
729-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
730-
uint16_t4 countbits(uint16_t4);
736+
constexpr uint4 countbits(uint16_t4 x) {
737+
return __builtin_elementwise_popcount(x);
738+
}
731739
#endif
732740

733-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
734-
int countbits(int);
735-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
736-
int2 countbits(int2);
737-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
738-
int3 countbits(int3);
739-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
740-
int4 countbits(int4);
741-
742-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
743-
uint countbits(uint);
744-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
745-
uint2 countbits(uint2);
746-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
747-
uint3 countbits(uint3);
748-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
749-
uint4 countbits(uint4);
750-
751-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
752-
int64_t countbits(int64_t);
753-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
754-
int64_t2 countbits(int64_t2);
755-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
756-
int64_t3 countbits(int64_t3);
757-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
758-
int64_t4 countbits(int64_t4);
759-
760-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
761-
uint64_t countbits(uint64_t);
762-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
763-
uint64_t2 countbits(uint64_t2);
764-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
765-
uint64_t3 countbits(uint64_t3);
766-
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
767-
uint64_t4 countbits(uint64_t4);
741+
constexpr uint countbits(int x) {
742+
return __builtin_elementwise_popcount(x);
743+
}
744+
constexpr uint2 countbits(int2 x) {
745+
return __builtin_elementwise_popcount(x);
746+
}
747+
constexpr uint3 countbits(int3 x) {
748+
return __builtin_elementwise_popcount(x);
749+
}
750+
constexpr uint4 countbits(int4 x) {
751+
return __builtin_elementwise_popcount(x);
752+
}
753+
754+
constexpr uint countbits(uint x) {
755+
return __builtin_elementwise_popcount(x);
756+
}
757+
constexpr uint2 countbits(uint2 x) {
758+
return __builtin_elementwise_popcount(x);
759+
}
760+
constexpr uint3 countbits(uint3 x) {
761+
return __builtin_elementwise_popcount(x);
762+
}
763+
constexpr uint4 countbits(uint4 x) {
764+
return __builtin_elementwise_popcount(x);
765+
}
766+
767+
constexpr uint countbits(int64_t x) {
768+
return __builtin_elementwise_popcount(x);
769+
}
770+
constexpr uint2 countbits(int64_t2 x) {
771+
return __builtin_elementwise_popcount(x);
772+
}
773+
constexpr uint3 countbits(int64_t3 x) {
774+
return __builtin_elementwise_popcount(x);
775+
}
776+
constexpr uint4 countbits(int64_t4 x) {
777+
return __builtin_elementwise_popcount(x);
778+
}
779+
780+
constexpr uint countbits(uint64_t x) {
781+
return __builtin_elementwise_popcount(x);
782+
}
783+
constexpr uint2 countbits(uint64_t2 x) {
784+
return __builtin_elementwise_popcount(x);
785+
}
786+
constexpr uint3 countbits(uint64_t3 x) {
787+
return __builtin_elementwise_popcount(x);
788+
}
789+
constexpr uint4 countbits(uint64_t4 x) {
790+
return __builtin_elementwise_popcount(x);
791+
}
768792

769793
//===----------------------------------------------------------------------===//
770794
// degrees builtins

clang/test/CodeGenHLSL/builtins/countbits.hlsl

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,38 @@
44

55
#ifdef __HLSL_ENABLE_16_BIT
66
// CHECK-LABEL: test_countbits_ushort
7-
// CHECK: call i16 @llvm.ctpop.i16
8-
uint16_t test_countbits_ushort(uint16_t p0)
7+
// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
8+
// CHECK-NEXT: zext i16 [[A]] to i32
9+
uint test_countbits_ushort(uint16_t p0)
910
{
1011
return countbits(p0);
1112
}
1213
// CHECK-LABEL: test_countbits_ushort2
13-
// CHECK: call <2 x i16> @llvm.ctpop.v2i16
14-
uint16_t2 test_countbits_ushort2(uint16_t2 p0)
14+
// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
15+
// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
16+
uint2 test_countbits_ushort2(uint16_t2 p0)
1517
{
1618
return countbits(p0);
1719
}
1820
// CHECK-LABEL: test_countbits_ushort3
19-
// CHECK: call <3 x i16> @llvm.ctpop.v3i16
20-
uint16_t3 test_countbits_ushort3(uint16_t3 p0)
21+
// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
22+
// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
23+
uint3 test_countbits_ushort3(uint16_t3 p0)
2124
{
2225
return countbits(p0);
2326
}
2427
// CHECK-LABEL: test_countbits_ushort4
25-
// CHECK: call <4 x i16> @llvm.ctpop.v4i16
26-
uint16_t4 test_countbits_ushort4(uint16_t4 p0)
28+
// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
29+
// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
30+
uint4 test_countbits_ushort4(uint16_t4 p0)
2731
{
2832
return countbits(p0);
2933
}
3034
#endif
3135

3236
// CHECK-LABEL: test_countbits_uint
3337
// CHECK: call i32 @llvm.ctpop.i32
34-
int test_countbits_uint(uint p0)
38+
uint test_countbits_uint(uint p0)
3539
{
3640
return countbits(p0);
3741
}
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
5559
}
5660

5761
// CHECK-LABEL: test_countbits_long
58-
// CHECK: call i64 @llvm.ctpop.i64
59-
uint64_t test_countbits_long(uint64_t p0)
62+
// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
63+
// CHECK-NEXT: trunc i64 [[A]] to i32
64+
uint test_countbits_long(uint64_t p0)
6065
{
6166
return countbits(p0);
6267
}
6368
// CHECK-LABEL: test_countbits_long2
64-
// CHECK: call <2 x i64> @llvm.ctpop.v2i64
65-
uint64_t2 test_countbits_long2(uint64_t2 p0)
69+
// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
70+
// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
71+
uint2 test_countbits_long2(uint64_t2 p0)
6672
{
6773
return countbits(p0);
6874
}
6975
// CHECK-LABEL: test_countbits_long3
70-
// CHECK: call <3 x i64> @llvm.ctpop.v3i64
71-
uint64_t3 test_countbits_long3(uint64_t3 p0)
76+
// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
77+
// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
78+
uint3 test_countbits_long3(uint64_t3 p0)
7279
{
7380
return countbits(p0);
7481
}
7582
// CHECK-LABEL: test_countbits_long4
76-
// CHECK: call <4 x i64> @llvm.ctpop.v4i64
77-
uint64_t4 test_countbits_long4(uint64_t4 p0)
83+
// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
84+
// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
85+
uint4 test_countbits_long4(uint64_t4 p0)
7886
{
7987
return countbits(p0);
8088
}
Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
// RUN: %clang_cc1 -finclude-default-header
2-
// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
3-
// -disable-llvm-passes -verify -verify-ignore-unexpected
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
42

53

64
double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
97
}
108

119
double2 test_int_builtin_2(double2 p0) {
12-
return __builtin_elementwise_popcount(p0);
13-
// expected-error@-1 {{1st argument must be a vector of integers
14-
// (was 'double2' (aka 'vector<double, 2>'))}}
10+
return countbits(p0);
11+
// expected-error@-1 {{call to 'countbits' is ambiguous}}
1512
}
1613

1714
double test_int_builtin_3(float p0) {
18-
return __builtin_elementwise_popcount(p0);
19-
// expected-error@-1 {{1st argument must be a vector of integers
20-
// (was 'float')}}
15+
return countbits(p0);
16+
// expected-error@-1 {{call to 'countbits' is ambiguous}}
2117
}

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,11 +553,10 @@ def Rbits : DXILOp<30, unary> {
553553
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
554554
}
555555

556-
def CBits : DXILOp<31, unary> {
556+
def CBits : DXILOp<31, unaryBits> {
557557
let Doc = "Returns the number of 1 bits in the specified value.";
558-
let LLVMIntrinsic = int_ctpop;
559558
let arguments = [OverloadTy];
560-
let result = OverloadTy;
559+
let result = Int32Ty;
561560
let overloads =
562561
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
563562
let stages = [Stages<DXIL1_0, [all_stages]>];

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,67 @@ class OpLowerer {
460460
});
461461
}
462462

463+
[[nodiscard]] bool lowerCtpopToCBits(Function &F) {
464+
IRBuilder<> &IRB = OpBuilder.getIRB();
465+
Type *Int32Ty = IRB.getInt32Ty();
466+
467+
return replaceFunction(F, [&](CallInst *CI) -> Error {
468+
IRB.SetInsertPoint(CI);
469+
SmallVector<Value *> Args;
470+
Args.append(CI->arg_begin(), CI->arg_end());
471+
472+
Type *RetTy = Int32Ty;
473+
Type *FRT = F.getReturnType();
474+
if (FRT->isVectorTy()) {
475+
VectorType *VT = cast<VectorType>(FRT);
476+
RetTy = VectorType::get(RetTy, VT);
477+
}
478+
479+
Expected<CallInst *> OpCall =
480+
OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
481+
if (Error E = OpCall.takeError())
482+
return E;
483+
484+
// If the result type is 32 bits we can do a direct replacement.
485+
if (FRT->isIntOrIntVectorTy(32)) {
486+
CI->replaceAllUsesWith(*OpCall);
487+
CI->eraseFromParent();
488+
return Error::success();
489+
}
490+
491+
unsigned CastOp;
492+
if (FRT->isIntOrIntVectorTy(16))
493+
CastOp = Instruction::ZExt;
494+
else // must be 64 bits
495+
CastOp = Instruction::Trunc;
496+
497+
// It is correct to replace the ctpop with the dxil op and
498+
// remove an existing cast iff the cast is the only use of
499+
// the ctpop.
500+
// We can use hasOneUse instead of hasOneUser because the user
501+
// we care about (a cast) should have one operand.
502+
if (CI->hasOneUse()) {
503+
User *U = CI->user_back();
504+
Instruction *I;
505+
if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
506+
I->getOpcode() == CastOp && I->getType() == RetTy) {
507+
I->replaceAllUsesWith(*OpCall);
508+
I->eraseFromParent();
509+
CI->eraseFromParent();
510+
return Error::success();
511+
}
512+
}
513+
514+
// It is always correct to replace a ctpop with the dxil op and
515+
// a cast back to the ctpop's original result type.
516+
Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
517+
"ctpop.cast");
518+
CI->replaceAllUsesWith(Cast);
519+
CI->eraseFromParent();
520+
return Error::success();
521+
});
522+
}
523+
463524
bool lowerIntrinsics() {
464525
bool Updated = false;
465526
bool HasErrors = false;
@@ -488,6 +549,9 @@ class OpLowerer {
488549
case Intrinsic::dx_typedBufferStore:
489550
HasErrors |= lowerTypedBufferStore(F);
490551
break;
552+
case Intrinsic::ctpop:
553+
HasErrors |= lowerCtpopToCBits(F);
554+
break;
491555
}
492556
Updated = true;
493557
}

0 commit comments

Comments
 (0)