Skip to content

Commit da2e3c1

Browse files
committed
Add support for frexp. Move the struct-of-vectors lookup out of getVectorSplit so that it is done only in the CallInst and ExtractValueInst visits.
1 parent 9a5042f commit da2e3c1

File tree

4 files changed

+133
-23
lines changed

4 files changed

+133
-23
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,5 @@ def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32
9191
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
9292
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
9393
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
94-
9594
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
9695
}

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,15 @@ struct VectorLayout {
197197
uint64_t SplitSize = 0;
198198
};
199199

200-
static bool isStructOfVectors(Type *Ty) {
201-
return isa<StructType>(Ty) && Ty->getNumContainedTypes() > 0 &&
202-
isa<FixedVectorType>(Ty->getContainedType(0));
200+
/// Return true if \p Ty is a non-empty struct type in which every member is a
/// fixed-width vector and all members have the same number of elements.
///
/// Callers use member 0 of the struct to compute the vector split and then
/// apply that one split to the whole struct value. Two cases must therefore be
/// rejected here rather than at each call site:
///  - an empty struct, which has no member 0 to inspect, and
///  - a struct whose members disagree on element count, for which a single
///    split cannot describe every member.
static bool isStructAllVectors(Type *Ty) {
  if (!isa<StructType>(Ty))
    return false;

  const unsigned NumMembers = Ty->getNumContainedTypes();
  if (NumMembers == 0)
    return false;

  auto *FirstVecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
  if (!FirstVecTy)
    return false;

  // Every remaining member must be a fixed vector of the same width as the
  // first one.
  const unsigned NumElems = FirstVecTy->getNumElements();
  for (unsigned I = 1; I < NumMembers; ++I) {
    auto *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
    if (!VecTy || VecTy->getNumElements() != NumElems)
      return false;
  }

  return true;
}
204210

205211
/// Concatenate the given fragments to a single vector value of the type
@@ -558,10 +564,7 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
558564
// Determine how Ty is split, if at all.
559565
std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
560566
VectorSplit Split;
561-
if (isStructOfVectors(Ty))
562-
Split.VecTy = cast<FixedVectorType>(Ty->getContainedType(0));
563-
else
564-
Split.VecTy = dyn_cast<FixedVectorType>(Ty);
567+
Split.VecTy = dyn_cast<FixedVectorType>(Ty);
565568
if (!Split.VecTy)
566569
return {};
567570

@@ -676,14 +679,24 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
676679
/// Return true if the intrinsic \p ID is one this visitor treats as trivially
/// scalarizable.
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
  // frexp is accepted explicitly, on top of everything that
  // isTriviallyVectorizable already accepts.
  if (isTriviallyVectorizable(ID) || ID == Intrinsic::frexp)
    return true;

  // Anything else must be a target intrinsic, and the target has the final
  // say through TTI.
  if (!Intrinsic::isTargetIntrinsic(ID))
    return false;
  return TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}
682689

683690
/// If a call to a vector typed intrinsic function, split into a scalar call per
684691
/// element if possible for the intrinsic.
685692
bool ScalarizerVisitor::splitCall(CallInst &CI) {
686-
std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
693+
Type* CallType = CI.getType();
694+
bool areAllVectors = isStructAllVectors(CallType);
695+
std::optional<VectorSplit> VS;
696+
if (areAllVectors)
697+
VS = getVectorSplit(CallType->getContainedType(0));
698+
else
699+
VS = getVectorSplit(CallType);
687700
if (!VS)
688701
return false;
689702

@@ -708,6 +721,18 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
708721
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
709722
Tys.push_back(VS->SplitTy);
710723

724+
if(areAllVectors) {
725+
Type* PrevType = CallType->getContainedType(0);
726+
Type* CallType = CI.getType();
727+
for(unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
728+
Type* CurrType = cast<FixedVectorType>(CallType->getContainedType(I));
729+
if(PrevType != CurrType) {
730+
std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
731+
Tys.push_back(CurrVS->SplitTy);
732+
PrevType = CurrType;
733+
}
734+
}
735+
}
711736
// Assumes that any vector type has the same number of elements as the return
712737
// vector type, which is true for all current intrinsics.
713738
for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1043,15 +1068,13 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
10431068
Value *Op = EVI.getOperand(0);
10441069
Type *OpTy = Op->getType();
10451070
ValueVector Res;
1046-
if (!isStructOfVectors(OpTy))
1071+
if (!isStructAllVectors(OpTy))
10471072
return false;
1048-
// Note: isStructOfVectors is also used in getVectorSplit.
1049-
// The intent is to bail on this visit if it isn't a struct
1050-
// of vectors. Downside is that when it is true we do two
1051-
// isStructOfVectors calls.
1052-
std::optional<VectorSplit> VS = getVectorSplit(OpTy);
1073+
Type* VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
1074+
std::optional<VectorSplit> VS = getVectorSplit(VecType);
10531075
if (!VS)
10541076
return false;
1077+
IRBuilder<> Builder(&EVI);
10551078
Scatterer Op0 = scatter(&EVI, Op, *VS);
10561079
assert(!EVI.getIndices().empty() && "Make sure an index exists");
10571080
// Note for our use case we only care about the top level index.
@@ -1252,7 +1275,7 @@ bool ScalarizerVisitor::finish() {
12521275
Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
12531276

12541277
// Iterate over each element in the struct
1255-
uint NumOfStructElements = Ty->getNumElements();
1278+
unsigned NumOfStructElements = Ty->getNumElements();
12561279
SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
12571280
for (unsigned I = 0; I < NumOfStructElements; ++I) {
12581281
for (auto *CVelem : CV) {
Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,32 @@
1+
; RUN: opt -S -scalarizer -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
12

2-
; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
; Calls llvm.dx.splitdouble on a <3 x double> and never uses the returned
; struct; no CHECK lines are attached to this function.
; NOTE(review): presumably a does-not-crash test for a struct-of-vectors call
; whose result has no extractvalue users -- confirm.
define void @test_vector_double_split_void(<3 x double> noundef %d) {
4+
%hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
5+
ret void
6+
}
37

4-
define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) local_unnamed_addr {
5-
%hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
6-
%1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
7-
%2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
8-
%3 = add <3 x i32> %1, %2
9-
ret <3 x i32> %3
8+
; Checks that a vector llvm.dx.splitdouble call is split into one scalar call
; per element, and that the two extractvalue users and the add are rewritten
; to work on the scalar { i32, i32 } results.
;
; The label anchors the checks below to this function. The trailing '(' is
; required: without it the label would also match
; @test_vector_double_split_void, which shares this name as a prefix.
; CHECK-LABEL: @test_vector_double_split(
define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
  ; CHECK: [[ee0:%.*]] = extractelement <3 x double> %d, i64 0
  ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
  ; CHECK: [[ee1:%.*]] = extractelement <3 x double> %d, i64 1
  ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
  ; CHECK: [[ee2:%.*]] = extractelement <3 x double> %d, i64 2
  ; CHECK: [[ie2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee2]])
  ; CHECK: [[ev00:%.*]] = extractvalue { i32, i32 } [[ie0]], 0
  ; CHECK: [[ev01:%.*]] = extractvalue { i32, i32 } [[ie1]], 0
  ; CHECK: [[ev02:%.*]] = extractvalue { i32, i32 } [[ie2]], 0
  ; CHECK: [[ev10:%.*]] = extractvalue { i32, i32 } [[ie0]], 1
  ; CHECK: [[ev11:%.*]] = extractvalue { i32, i32 } [[ie1]], 1
  ; CHECK: [[ev12:%.*]] = extractvalue { i32, i32 } [[ie2]], 1
  ; CHECK: [[add1:%.*]] = add i32 [[ev00]], [[ev10]]
  ; CHECK: [[add2:%.*]] = add i32 [[ev01]], [[ev11]]
  ; CHECK: [[add3:%.*]] = add i32 [[ev02]], [[ev12]]
  ; CHECK: insertelement <3 x i32> poison, i32 [[add1]], i64 0
  ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add2]], i64 1
  ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add3]], i64 2
  %hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
  %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
  %3 = add <3 x i32> %1, %2
  ret <3 x i32> %3
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
; RUN: opt %s -passes='function(scalarizer<load-store>)' -S | FileCheck %s
2+
3+
; CHECK-LABEL: @test_vector_half_frexp_half
4+
; Vector frexp on <2 x half> where only struct field 0 (the <2 x half> result)
; is extracted. The callee is spelled with the v2f16 suffix so that the name
; agrees with the <2 x half> operand, matching the f16 suffix that the scalar
; calls in the CHECK lines use for half.
define noundef <2 x half> @test_vector_half_frexp_half(<2 x half> noundef %h) {
  ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
  ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
  ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { half, i32 } [[ie0]], 0
  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { half, i32 } [[ie1]], 0
  ; CHECK-NEXT: insertelement <2 x half> poison, half [[ev00]], i64 0
  ; CHECK-NEXT: insertelement <2 x half> %{{.*}}, half [[ev01]], i64 1
  %r = call { <2 x half>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x half> %h)
  %e0 = extractvalue { <2 x half>, <2 x i32> } %r, 0
  ret <2 x half> %e0
}
17+
18+
; CHECK-LABEL: @test_vector_half_frexp_int
19+
; Vector frexp on <2 x half> where only struct field 1 (the <2 x i32> result)
; is extracted. The callee is spelled with the v2f16 suffix so that the name
; agrees with the <2 x half> operand, matching the f16 suffix that the scalar
; calls in the CHECK lines use for half.
define noundef <2 x i32> @test_vector_half_frexp_int(<2 x half> noundef %h) {
  ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
  ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
  ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
  ; CHECK-NEXT: [[ev10:%.*]] = extractvalue { half, i32 } [[ie0]], 1
  ; CHECK-NEXT: [[ev11:%.*]] = extractvalue { half, i32 } [[ie1]], 1
  ; CHECK-NEXT: insertelement <2 x i32> poison, i32 [[ev10]], i64 0
  ; CHECK-NEXT: insertelement <2 x i32> %{{.*}}, i32 [[ev11]], i64 1
  %r = call { <2 x half>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x half> %h)
  %e1 = extractvalue { <2 x half>, <2 x i32> } %r, 1
  ret <2 x i32> %e1
}
32+
33+
34+
; Vector frexp on <2 x float> where both struct fields are extracted but only
; field 0 is returned. The callee is spelled with the v2f32 suffix so that the
; name agrees with the <2 x float> operand, matching the f32 suffix that the
; scalar calls in the CHECK lines use for float.
;
; The label anchors the checks below to this function, as the half tests in
; this file already do.
; CHECK-LABEL: @test_vector_float_frexp_int(
define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f) {
  ; CHECK: [[ee0:%.*]] = extractelement <2 x float> %f, i64 0
  ; CHECK-NEXT: [[ie0:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee0]])
  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x float> %f, i64 1
  ; CHECK-NEXT: [[ie1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee1]])
  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { float, i32 } [[ie0]], 0
  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { float, i32 } [[ie1]], 0
  ; CHECK-NEXT: insertelement <2 x float> poison, float [[ev00]], i64 0
  ; CHECK-NEXT: insertelement <2 x float> %{{.*}}, float [[ev01]], i64 1
  ; CHECK-NEXT: extractvalue { float, i32 } [[ie0]], 1
  ; CHECK-NEXT: extractvalue { float, i32 } [[ie1]], 1
  %1 = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %f)
  %2 = extractvalue { <2 x float>, <2 x i32> } %1, 0
  %3 = extractvalue { <2 x float>, <2 x i32> } %1, 1
  ret <2 x float> %2
}
50+
51+
; Vector frexp on <2 x double> where both struct fields are extracted but only
; field 0 is returned.
;
; The label anchors the checks below to this function, as the half tests in
; this file already do.
; CHECK-LABEL: @test_vector_double_frexp_int(
define noundef <2 x double> @test_vector_double_frexp_int(<2 x double> noundef %d) {
  ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
  ; CHECK-NEXT: [[ie0:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee0]])
  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
  ; CHECK-NEXT: [[ie1:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee1]])
  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { double, i32 } [[ie0]], 0
  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { double, i32 } [[ie1]], 0
  ; CHECK-NEXT: insertelement <2 x double> poison, double [[ev00]], i64 0
  ; CHECK-NEXT: insertelement <2 x double> %{{.*}}, double [[ev01]], i64 1
  ; CHECK-NEXT: extractvalue { double, i32 } [[ie0]], 1
  ; CHECK-NEXT: extractvalue { double, i32 } [[ie1]], 1
  %1 = call { <2 x double>, <2 x i32> } @llvm.frexp.v2f64.v2i32(<2 x double> %d)
  %2 = extractvalue { <2 x double>, <2 x i32> } %1, 0
  %3 = extractvalue { <2 x double>, <2 x i32> } %1, 1
  ret <2 x double> %2
}

0 commit comments

Comments
 (0)