Skip to content

Commit 9a5042f

Browse files
committed
[Scalarizer] A change to let the scalarizer pass be able to scalarize structs
1 parent 0afe6e4 commit 9a5042f

File tree

4 files changed

+119
-40
lines changed

4 files changed

+119
-40
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,8 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon
8989
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
9090
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
9191
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
92+
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
93+
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
94+
9295
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
9396
}
Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,39 @@
1-
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
2-
//-*-===//
3-
//
4-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5-
// See https://llvm.org/LICENSE.txt for license information.
6-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7-
//
8-
//===----------------------------------------------------------------------===//
9-
///
10-
//===----------------------------------------------------------------------===//
11-
12-
#include "DirectXTargetTransformInfo.h"
13-
#include "llvm/IR/Intrinsics.h"
14-
#include "llvm/IR/IntrinsicsDirectX.h"
15-
16-
using namespace llvm;
17-
18-
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19-
unsigned ScalarOpdIdx) {
20-
switch (ID) {
21-
case Intrinsic::dx_wave_readlane:
22-
return ScalarOpdIdx == 1;
23-
default:
24-
return false;
25-
}
26-
}
27-
28-
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
29-
Intrinsic::ID ID) const {
30-
switch (ID) {
31-
case Intrinsic::dx_frac:
32-
case Intrinsic::dx_rsqrt:
33-
case Intrinsic::dx_wave_readlane:
34-
return true;
35-
default:
36-
return false;
37-
}
38-
}
1+
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "DirectXTargetTransformInfo.h"
13+
#include "llvm/IR/Intrinsics.h"
14+
#include "llvm/IR/IntrinsicsDirectX.h"
15+
16+
using namespace llvm;
17+
18+
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19+
unsigned ScalarOpdIdx) {
20+
switch (ID) {
21+
case Intrinsic::dx_wave_readlane:
22+
return ScalarOpdIdx == 1;
23+
default:
24+
return false;
25+
}
26+
}
27+
28+
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
29+
Intrinsic::ID ID) const {
30+
switch (ID) {
31+
case Intrinsic::dx_frac:
32+
case Intrinsic::dx_rsqrt:
33+
case Intrinsic::dx_wave_readlane:
34+
case Intrinsic::dx_splitdouble:
35+
return true;
36+
default:
37+
return false;
38+
}
39+
}

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ struct VectorLayout {
197197
uint64_t SplitSize = 0;
198198
};
199199

200+
static bool isStructOfVectors(Type *Ty) {
201+
return isa<StructType>(Ty) && Ty->getNumContainedTypes() > 0 &&
202+
isa<FixedVectorType>(Ty->getContainedType(0));
203+
}
204+
200205
/// Concatenate the given fragments to a single vector value of the type
201206
/// described in @p VS.
202207
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +281,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
276281
bool visitBitCastInst(BitCastInst &BCI);
277282
bool visitInsertElementInst(InsertElementInst &IEI);
278283
bool visitExtractElementInst(ExtractElementInst &EEI);
284+
bool visitExtractValueInst(ExtractValueInst &EVI);
279285
bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
280286
bool visitPHINode(PHINode &PHI);
281287
bool visitLoadInst(LoadInst &LI);
@@ -552,7 +558,10 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
552558
// Determine how Ty is split, if at all.
553559
std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
554560
VectorSplit Split;
555-
Split.VecTy = dyn_cast<FixedVectorType>(Ty);
561+
if (isStructOfVectors(Ty))
562+
Split.VecTy = cast<FixedVectorType>(Ty->getContainedType(0));
563+
else
564+
Split.VecTy = dyn_cast<FixedVectorType>(Ty);
556565
if (!Split.VecTy)
557566
return {};
558567

@@ -1030,6 +1039,33 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
10301039
return true;
10311040
}
10321041

1042+
bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
1043+
Value *Op = EVI.getOperand(0);
1044+
Type *OpTy = Op->getType();
1045+
ValueVector Res;
1046+
if (!isStructOfVectors(OpTy))
1047+
return false;
1048+
// Note: isStructOfVectors is also used in getVectorSplit.
1049+
// The intent is to bail on this visit if it isn't a struct
1050+
// of vectors. Downside is that when it is true we do two
1051+
// isStructOfVectors calls.
1052+
std::optional<VectorSplit> VS = getVectorSplit(OpTy);
1053+
if (!VS)
1054+
return false;
1055+
Scatterer Op0 = scatter(&EVI, Op, *VS);
1056+
assert(!EVI.getIndices().empty() && "Make sure an index exists");
1057+
// Note for our use case we only care about the top level index.
1058+
unsigned Index = EVI.getIndices()[0];
1059+
for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
1060+
Value *ResElem = Builder.CreateExtractValue(
1061+
Op0[OpIdx], Index, EVI.getName() + ".elem" + std::to_string(Index));
1062+
Res.push_back(ResElem);
1063+
}
1064+
// replaceUses(&EVI, Res);
1065+
gather(&EVI, Res, *VS);
1066+
return true;
1067+
}
1068+
10331069
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
10341070
std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
10351071
if (!VS)
@@ -1196,7 +1232,7 @@ bool ScalarizerVisitor::finish() {
11961232
if (!Op->use_empty()) {
11971233
// The value is still needed, so recreate it using a series of
11981234
// insertelements and/or shufflevectors.
1199-
Value *Res;
1235+
Value *Res = nullptr;
12001236
if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
12011237
BasicBlock *BB = Op->getParent();
12021238
IRBuilder<> Builder(Op);
@@ -1209,6 +1245,35 @@ bool ScalarizerVisitor::finish() {
12091245
Res = concatenate(Builder, CV, VS, Op->getName());
12101246

12111247
Res->takeName(Op);
1248+
} else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
1249+
BasicBlock *BB = Op->getParent();
1250+
IRBuilder<> Builder(Op);
1251+
if (isa<PHINode>(Op))
1252+
Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
1253+
1254+
// Iterate over each element in the struct
1255+
uint NumOfStructElements = Ty->getNumElements();
1256+
SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
1257+
for (unsigned I = 0; I < NumOfStructElements; ++I) {
1258+
for (auto *CVelem : CV) {
1259+
Value *Elem = Builder.CreateExtractValue(
1260+
CVelem, I, Op->getName() + ".elem" + std::to_string(I));
1261+
ElemCV[I].push_back(Elem);
1262+
}
1263+
}
1264+
Res = PoisonValue::get(Ty);
1265+
for (unsigned I = 0; I < NumOfStructElements; ++I) {
1266+
Type *ElemTy = Ty->getElementType(I);
1267+
assert(isa<FixedVectorType>(ElemTy) &&
1268+
"Only Structs of all FixedVectorType supported");
1269+
VectorSplit VS = *getVectorSplit(ElemTy);
1270+
assert(VS.NumFragments == CV.size());
1271+
1272+
Value *ConcatenatedVector =
1273+
concatenate(Builder, ElemCV[I], VS, Op->getName());
1274+
Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
1275+
Op->getName() + ".insert");
1276+
}
12121277
} else {
12131278
assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
12141279
Res = CV[0];
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) local_unnamed_addr {
5+
%hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
6+
%1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
7+
%2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
8+
%3 = add <3 x i32> %1, %2
9+
ret <3 x i32> %3
10+
}

0 commit comments

Comments
 (0)