Skip to content

Commit c03a0f1

Browse files
committed
instead of detecting the conflicts lets pick the smallest value for the alloca then keep the cast but change the input type.
1 parent de3024d commit c03a0f1

File tree

2 files changed

+116
-58
lines changed

2 files changed

+116
-58
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,23 @@ static void fixI8UseChain(Instruction &I,
145145
}
146146

147147
if (auto *Cast = dyn_cast<CastInst>(&I)) {
148-
149-
if (Cast->getSrcTy()->isIntegerTy(8)) {
150-
ToRemove.push_back(Cast);
151-
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
148+
if (!Cast->getSrcTy()->isIntegerTy(8))
149+
return;
150+
151+
ToRemove.push_back(Cast);
152+
auto* Replacement =ReplacedValues[Cast->getOperand(0)];
153+
if (Cast->getType() == Replacement->getType()) {
154+
Cast->replaceAllUsesWith(Replacement);
155+
return;
152156
}
157+
Value* AdjustedCast = nullptr;
158+
if (Cast->getOpcode() == Instruction::ZExt)
159+
AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
160+
if (Cast->getOpcode() == Instruction::SExt)
161+
AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
162+
163+
if(AdjustedCast)
164+
Cast->replaceAllUsesWith(AdjustedCast);
153165
}
154166
}
155167

@@ -160,8 +172,9 @@ static void upcastI8AllocasAndUses(Instruction &I,
160172
if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
161173
return;
162174

163-
std::optional<Type *> TargetType;
164-
bool Conflict = false;
175+
Type *SmallestType = nullptr;
176+
177+
// Gather all cast targets
165178
for (User *U : AI->users()) {
166179
auto *Load = dyn_cast<LoadInst>(U);
167180
if (!Load)
@@ -170,21 +183,20 @@ static void upcastI8AllocasAndUses(Instruction &I,
170183
auto *Cast = dyn_cast<CastInst>(LU);
171184
if (!Cast)
172185
continue;
173-
Type *T = Cast->getType();
174-
if (!TargetType)
175-
TargetType = T;
176-
177-
if (TargetType.value() != T) {
178-
Conflict = true;
179-
break;
180-
}
186+
Type *Ty = Cast->getType();
187+
if (!SmallestType ||
188+
Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
189+
SmallestType = Ty;
181190
}
182191
}
183-
if (!TargetType || Conflict)
184-
return;
185192

193+
if (!SmallestType)
194+
return; // no valid casts found
195+
196+
// Replace alloca
186197
IRBuilder<> Builder(AI);
187-
AllocaInst *NewAlloca = Builder.CreateAlloca(TargetType.value());
198+
auto *NewAlloca =
199+
Builder.CreateAlloca(SmallestType);
188200
ReplacedValues[AI] = NewAlloca;
189201
ToRemove.push_back(AI);
190202
}
Lines changed: 87 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,99 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
12
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
23

34
define void @const_i8_store() {
4-
%accum.i.flat = alloca [1 x i32], align 4
5-
%i = alloca i8, align 4
6-
store i8 1, ptr %i
7-
%i8.load = load i8, ptr %i
8-
%z = zext i8 %i8.load to i32
9-
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
10-
store i32 %z, ptr %gep, align 4
11-
ret void
5+
; CHECK-LABEL: define void @const_i8_store() {
6+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
7+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
8+
; CHECK-NEXT: store i32 1, ptr [[TMP1]], align 4
9+
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
10+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
11+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
12+
; CHECK-NEXT: ret void
13+
;
14+
%accum.i.flat = alloca [1 x i32], align 4
15+
%i = alloca i8, align 4
16+
store i8 1, ptr %i
17+
%i8.load = load i8, ptr %i
18+
%z = zext i8 %i8.load to i32
19+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
20+
store i32 %z, ptr %gep, align 4
21+
ret void
1222
}
1323

1424
define void @const_add_i8_store() {
15-
%accum.i.flat = alloca [1 x i32], align 4
16-
%i = alloca i8, align 4
17-
%add_i8 = add nsw i8 3, 1
18-
store i8 %add_i8, ptr %i
19-
%i8.load = load i8, ptr %i
20-
%z = zext i8 %i8.load to i32
21-
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
22-
store i32 %z, ptr %gep, align 4
23-
ret void
25+
; CHECK-LABEL: define void @const_add_i8_store() {
26+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
27+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
28+
; CHECK-NEXT: store i32 4, ptr [[TMP1]], align 4
29+
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
30+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
31+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[GEP]], align 4
32+
; CHECK-NEXT: ret void
33+
;
34+
%accum.i.flat = alloca [1 x i32], align 4
35+
%i = alloca i8, align 4
36+
%add_i8 = add nsw i8 3, 1
37+
store i8 %add_i8, ptr %i
38+
%i8.load = load i8, ptr %i
39+
%z = zext i8 %i8.load to i32
40+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
41+
store i32 %z, ptr %gep, align 4
42+
ret void
2443
}
2544

2645
define void @var_i8_store(i1 %cmp.i8) {
27-
%accum.i.flat = alloca [1 x i32], align 4
28-
%i = alloca i8, align 4
29-
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
30-
store i8 %select.i8, ptr %i
31-
%i8.load = load i8, ptr %i
32-
%z = zext i8 %i8.load to i32
33-
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
34-
store i32 %z, ptr %gep, align 4
35-
ret void
46+
; CHECK-LABEL: define void @var_i8_store(
47+
; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
48+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
49+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i32, align 4
50+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
51+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
52+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP1]], align 4
53+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
54+
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP]], align 4
55+
; CHECK-NEXT: ret void
56+
;
57+
%accum.i.flat = alloca [1 x i32], align 4
58+
%i = alloca i8, align 4
59+
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
60+
store i8 %select.i8, ptr %i
61+
%i8.load = load i8, ptr %i
62+
%z = zext i8 %i8.load to i32
63+
%gep = getelementptr i32, ptr %accum.i.flat, i32 0
64+
store i32 %z, ptr %gep, align 4
65+
ret void
3666
}
3767

3868
define void @conflicting_cast(i1 %cmp.i8) {
39-
%accum.i.flat = alloca [2 x i32], align 4
40-
%i = alloca i8, align 4
41-
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
42-
store i8 %select.i8, ptr %i
43-
%i8.load = load i8, ptr %i
44-
%z = zext i8 %i8.load to i16
45-
%gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
46-
store i16 %z, ptr %gep1, align 2
47-
%gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
48-
store i16 %z, ptr %gep2, align 2
49-
%z2 = zext i8 %i8.load to i32
50-
%gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
51-
store i32 %z2, ptr %gep3, align 4
52-
ret void
53-
}
69+
; CHECK-LABEL: define void @conflicting_cast(
70+
; CHECK-SAME: i1 [[CMP_I8:%.*]]) {
71+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x i32], align 4
72+
; CHECK-NEXT: [[TMP1:%.*]] = alloca i16, align 2
73+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[CMP_I8]], i32 1, i32 2
74+
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP1]], align 4
75+
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[TMP1]], align 2
76+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 0
77+
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
78+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[ACCUM_I_FLAT]], i32 1
79+
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP2]], align 2
80+
; CHECK-NEXT: [[TMP4:%.*]] = zext i16 [[TMP3]] to i32
81+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 1
82+
; CHECK-NEXT: store i32 [[TMP4]], ptr [[GEP3]], align 4
83+
; CHECK-NEXT: ret void
84+
;
85+
%accum.i.flat = alloca [2 x i32], align 4
86+
%i = alloca i8, align 4
87+
%select.i8 = select i1 %cmp.i8, i8 1, i8 2
88+
store i8 %select.i8, ptr %i
89+
%i8.load = load i8, ptr %i
90+
%z = zext i8 %i8.load to i16
91+
%gep1 = getelementptr i16, ptr %accum.i.flat, i32 0
92+
store i16 %z, ptr %gep1, align 2
93+
%gep2 = getelementptr i16, ptr %accum.i.flat, i32 1
94+
store i16 %z, ptr %gep2, align 2
95+
%z2 = zext i8 %i8.load to i32
96+
%gep3 = getelementptr i32, ptr %accum.i.flat, i32 1
97+
store i32 %z2, ptr %gep3, align 4
98+
ret void
99+
}

0 commit comments

Comments
 (0)