Skip to content

Commit 69a224d

Browse files
committed
Fix recursive type
1 parent 51482db commit 69a224d

File tree

2 files changed

+66
-12
lines changed

2 files changed

+66
-12
lines changed

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ TypeAnalyzer::TypeAnalyzer(
186186
}
187187

188188
/// Given a constant value, deduce any type information applicable
189-
TypeTree getConstantAnalysis(Constant *Val, TypeAnalyzer &TA) {
189+
TypeTree
190+
getConstantAnalysis(Constant *Val, TypeAnalyzer &TA,
191+
SmallPtrSetImpl<GlobalVariable *> *seen = nullptr) {
190192
auto &DL = TA.fntypeinfo.Function->getParent()->getDataLayout();
191193
// Undefined value is an anything everywhere
192194
if (isa<UndefValue>(Val) || isa<ConstantAggregateZero>(Val)) {
@@ -238,9 +240,10 @@ TypeTree getConstantAnalysis(Constant *Val, TypeAnalyzer &TA) {
238240

239241
int Off = (int)ai.getLimitedValue();
240242

241-
Result |= getConstantAnalysis(Op, TA).ShiftIndices(DL, /*init offset*/ 0,
242-
/*maxSize*/ ObjSize,
243-
/*addOffset*/ Off);
243+
Result |= getConstantAnalysis(Op, TA, seen)
244+
.ShiftIndices(DL, /*init offset*/ 0,
245+
/*maxSize*/ ObjSize,
246+
/*addOffset*/ Off);
244247
Off += ObjSize;
245248
}
246249
if (TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits(
@@ -284,9 +287,10 @@ TypeTree getConstantAnalysis(Constant *Val, TypeAnalyzer &TA) {
284287

285288
int Off = (int)ai.getLimitedValue();
286289

287-
Result |= getConstantAnalysis(Op, TA).ShiftIndices(DL, /*init offset*/ 0,
288-
/*maxSize*/ ObjSize,
289-
/*addOffset*/ Off);
290+
Result |= getConstantAnalysis(Op, TA, seen)
291+
.ShiftIndices(DL, /*init offset*/ 0,
292+
/*maxSize*/ ObjSize,
293+
/*addOffset*/ Off);
290294
}
291295
if (TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits(
292296
CD->getType()) >= 16) {
@@ -334,10 +338,10 @@ TypeTree getConstantAnalysis(Constant *Val, TypeAnalyzer &TA) {
334338
if (CE->isCast()) {
335339
if (CE->getType()->isPointerTy() && isa<ConstantInt>(CE->getOperand(0)))
336340
return TypeTree(BaseType::Anything).Only(-1);
337-
return getConstantAnalysis(CE->getOperand(0), TA);
341+
return getConstantAnalysis(CE->getOperand(0), TA, seen);
338342
}
339343
if (CE->isGEPWithNoNotionalOverIndexing()) {
340-
auto gepData0 = getConstantAnalysis(CE->getOperand(0), TA).Data0();
344+
auto gepData0 = getConstantAnalysis(CE->getOperand(0), TA, seen).Data0();
341345

342346
auto g2 = cast<GetElementPtrInst>(CE->getAsInstruction());
343347
#if LLVM_VERSION_MAJOR > 6
@@ -380,10 +384,16 @@ TypeTree getConstantAnalysis(Constant *Val, TypeAnalyzer &TA) {
380384
}
381385

382386
if (auto GV = dyn_cast<GlobalVariable>(Val)) {
387+
if (seen && seen->count(GV))
388+
return TypeTree();
383389
// A fixed constant global is a pointer to its initializer
384390
if (GV->isConstant() && GV->hasInitializer()) {
391+
SmallPtrSet<GlobalVariable *, 2> seen2;
392+
if (seen)
393+
seen2.insert(seen->begin(), seen->end());
394+
seen2.insert(GV);
385395
TypeTree Result = ConcreteType(BaseType::Pointer);
386-
Result |= getConstantAnalysis(GV->getInitializer(), TA);
396+
Result |= getConstantAnalysis(GV->getInitializer(), TA, &seen2);
387397
return Result.Only(-1);
388398
}
389399
if (GV->getName() == "__cxa_thread_atexit_impl") {
@@ -998,8 +1008,6 @@ void TypeAnalyzer::visitConstantExpr(ConstantExpr &CE) {
9981008
pointerData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1,
9991009
/*new offset*/ off);
10001010
result.insert({}, BaseType::Pointer);
1001-
llvm::errs() << "CE: " << CE << " pdata0: " << pointerData0.str()
1002-
<< " off: " << off << " res: " << result.str() << "\n";
10031011
updateAnalysis(CE.getOperand(0), result.Only(-1), &CE);
10041012
}
10051013
return;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=callee -o /dev/null | FileCheck %s
2+
3+
4+
source_filename = "Awesome.bc"
5+
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-apple-macosx11.0.0"
7+
8+
%TSf = type <{ float }>
9+
%TSi = type <{ i64 }>
10+
%swift.type_descriptor = type opaque
11+
%swift.type = type { i64 }
12+
%swift.protocol_conformance_descriptor = type { i32, i32, i32, i32 }
13+
%Ts26DefaultStringInterpolationV = type <{ %TSS }>
14+
%TSS = type <{ %Ts11_StringGutsV }>
15+
%Ts11_StringGutsV = type <{ %Ts13_StringObjectV }>
16+
%Ts13_StringObjectV = type <{ %Ts6UInt64V, %swift.bridge* }>
17+
%Ts6UInt64V = type <{ i64 }>
18+
%swift.bridge = type opaque
19+
%swift.refcounted = type { %swift.type*, i64 }
20+
%swift.opaque = type opaque
21+
22+
@"$s7Awesome5valueSfvp" = hidden global %TSf zeroinitializer, align 4
23+
@"$s7Awesome11derivativesSf_Sftvp" = hidden global <{ %TSf, %TSf }> zeroinitializer, align 4
24+
@"$s7Awesome3sumSfvp" = hidden global %TSf zeroinitializer, align 4
25+
@"$s7Awesome10iterationsSivp" = hidden local_unnamed_addr global %TSi zeroinitializer, align 8
26+
@"$ss23_ContiguousArrayStorageCMn" = external global %swift.type_descriptor, align 4
27+
@"got.$ss23_ContiguousArrayStorageCMn" = private unnamed_addr constant %swift.type_descriptor* @"$ss23_ContiguousArrayStorageCMn"
28+
@"symbolic _____yypG s23_ContiguousArrayStorageC" = linkonce_odr hidden constant <{ i8, i32, [4 x i8], i8 }> <{ i8 2, i32 trunc (i64 sub (i64 ptrtoint (%swift.type_descriptor** @"got.$ss23_ContiguousArrayStorageCMn" to i64), i64 ptrtoint (i32* getelementptr inbounds (<{ i8, i32, [4 x i8], i8 }>, <{ i8, i32, [4 x i8], i8 }>* @"symbolic _____yypG s23_ContiguousArrayStorageC", i32 0, i32 1) to i64)) to i32), [4 x i8] c"yypG", i8 0 }>, section "__TEXT,__swift5_typeref, regular, no_dead_strip", align 2
29+
@ptr = linkonce_odr hidden global { i32, i32 } { i32 ptrtoint (<{ i8, i32, [4 x i8], i8 }>* @"symbolic _____yypG s23_ContiguousArrayStorageC" to i32), i32 17 }
30+
31+
define void @callee() {
32+
entry:
33+
%loadnotype = load i32, i32* getelementptr inbounds ({ i32, i32 }, { i32, i32 }* @ptr, i64 0, i32 1), align 4
34+
ret void
35+
}
36+
37+
!5 = !{!"omnipotent char", !6, i64 0}
38+
!6 = !{!"Simple C++ TBAA"}
39+
!7 = !{!"double", !5, i64 0}
40+
!8 = !{!7, !7, i64 0}
41+
42+
43+
; CHECK: callee - {} |
44+
; CHECK-NEXT: entry
45+
; CHECK-NEXT: %loadnotype = load i32, i32* getelementptr inbounds ({ i32, i32 }, { i32, i32 }* @ptr, i64 0, i32 1), align 4: {}
46+
; CHECK-NEXT: ret void: {}

0 commit comments

Comments
 (0)