Skip to content

Commit 996d897

Browse files
authored
Forward mode MPI (rust-lang#447)
1 parent 9232202 commit 996d897

File tree

12 files changed

+1394
-120
lines changed

12 files changed

+1394
-120
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 641 additions & 119 deletions
Large diffs are not rendered by default.

enzyme/Enzyme/Utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ enum class DerivativeMode {
250250
/// and to describe argument bundles.
251251
enum class ValueType {
252252
// A value that is neither a value in the original
253-
// prigram, nor the derivative.
253+
// program, nor the derivative.
254254
None = 0,
255255
// The original program value
256256
Primal = 1,
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; ModuleID = 'test/mpi.c'
4+
source_filename = "test/mpi.c"
5+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-unknown-linux-gnu"
7+
8+
%struct.ompi_predefined_datatype_t = type opaque
9+
%struct.ompi_predefined_communicator_t = type opaque
10+
%struct.ompi_status_public_t = type { i32, i32, i32, i32, i64 }
11+
%struct.ompi_datatype_t = type opaque
12+
%struct.ompi_communicator_t = type opaque
13+
14+
@ompi_mpi_real = external dso_local global %struct.ompi_predefined_datatype_t, align 1
15+
@ompi_mpi_comm_world = external dso_local global %struct.ompi_predefined_communicator_t, align 1
16+
@.str = private unnamed_addr constant [33 x i8] c"Process %d: vald %f, valeurd %f\0A\00", align 1
17+
@.str.1 = private unnamed_addr constant [31 x i8] c"Process %d: val %f, valeur %f\0A\00", align 1
18+
19+
; Function Attrs: nounwind uwtable
20+
define dso_local void @msg1(float* %val1, float* %val2, i32 %numprocprec, i32 %numprocsuiv, i32 %etiquette) #0 {
21+
entry:
22+
%statut = alloca %struct.ompi_status_public_t, align 8
23+
%0 = bitcast %struct.ompi_status_public_t* %statut to i8*
24+
%1 = bitcast float* %val1 to i8*
25+
%call = call i32 @MPI_Send(i8* %1, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocsuiv, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*)) #4
26+
%2 = bitcast float* %val2 to i8*
27+
%call1 = call i32 @MPI_Recv(i8* %2, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocprec, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*), %struct.ompi_status_public_t* nonnull %statut) #4
28+
ret void
29+
}
30+
31+
declare dso_local i32 @MPI_Send(i8*, i32, %struct.ompi_datatype_t*, i32, i32, %struct.ompi_communicator_t*)
32+
33+
declare dso_local i32 @MPI_Recv(i8*, i32, %struct.ompi_datatype_t*, i32, i32, %struct.ompi_communicator_t*, %struct.ompi_status_public_t*)
34+
35+
; Function Attrs: nounwind uwtable
36+
define dso_local void @msg2(float* %val1, float* %val2, i32 %numprocprec, i32 %numprocsuiv, i32 %etiquette) #0 {
37+
entry:
38+
%statut = alloca %struct.ompi_status_public_t, align 8
39+
%0 = bitcast %struct.ompi_status_public_t* %statut to i8*
40+
%1 = bitcast float* %val2 to i8*
41+
%call = call i32 @MPI_Recv(i8* %1, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocprec, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*), %struct.ompi_status_public_t* nonnull %statut) #4
42+
%2 = bitcast float* %val1 to i8*
43+
%call1 = call i32 @MPI_Send(i8* %2, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocsuiv, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*)) #4
44+
ret void
45+
}
46+
47+
; Function Attrs: nounwind uwtable
48+
define dso_local i32 @main(i32 %argc, i8** %argv) local_unnamed_addr #0 {
49+
entry:
50+
%argc.addr = alloca i32, align 4
51+
%argv.addr = alloca i8**, align 8
52+
%nb_processus = alloca i32, align 4
53+
%rang = alloca i32, align 4
54+
%val = alloca float, align 4
55+
%valeur = alloca float, align 4
56+
%vald = alloca float, align 4
57+
%valeurd = alloca float, align 4
58+
store i32 %argc, i32* %argc.addr, align 4, !tbaa !2
59+
store i8** %argv, i8*** %argv.addr, align 8, !tbaa !6
60+
%0 = bitcast i32* %nb_processus to i8*
61+
%1 = bitcast i32* %rang to i8*
62+
%2 = bitcast float* %val to i8*
63+
%3 = bitcast float* %valeur to i8*
64+
%call = call i32 @MPI_Init(i32* nonnull %argc.addr, i8*** nonnull %argv.addr) #4
65+
%call1 = call i32 @MPI_Comm_rank(%struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*), i32* nonnull %rang) #4
66+
%call2 = call i32 @MPI_Comm_size(%struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*), i32* nonnull %nb_processus) #4
67+
%4 = load i32, i32* %nb_processus, align 4, !tbaa !2
68+
%5 = load i32, i32* %rang, align 4, !tbaa !2
69+
%add = add i32 %4, -1
70+
%sub = add i32 %add, %5
71+
%rem = srem i32 %sub, %4
72+
%add3 = add nsw i32 %5, 1
73+
%rem4 = srem i32 %add3, %4
74+
%add5 = add nsw i32 %5, 1000
75+
%conv = sitofp i32 %add5 to float
76+
store float %conv, float* %val, align 4, !tbaa !8
77+
%6 = bitcast float* %vald to i8*
78+
%7 = bitcast float* %valeurd to i8*
79+
%add6 = add nsw i32 %5, 2000
80+
%conv7 = sitofp i32 %add6 to float
81+
store float %conv7, float* %valeurd, align 4, !tbaa !8
82+
%cmp = icmp eq i32 %5, 0
83+
%.sink = select i1 %cmp, i8* bitcast (void (float*, float*, i32, i32, i32)* @msg1 to i8*), i8* bitcast (void (float*, float*, i32, i32, i32)* @msg2 to i8*)
84+
call void (i8*, ...) @__enzyme_fwddiff(i8* %.sink, float* nonnull %val, float* nonnull %vald, float* nonnull %valeur, float* nonnull %valeurd, i32 %rem, i32 %rem4, i32 100) #4
85+
%8 = load i32, i32* %rang, align 4, !tbaa !2
86+
%9 = load float, float* %vald, align 4, !tbaa !8
87+
%conv9 = fpext float %9 to double
88+
%10 = load float, float* %valeurd, align 4, !tbaa !8
89+
%conv10 = fpext float %10 to double
90+
%call11 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([33 x i8], [33 x i8]* @.str, i64 0, i64 0), i32 %8, double %conv9, double %conv10)
91+
%11 = load i32, i32* %rang, align 4, !tbaa !2
92+
%12 = load float, float* %val, align 4, !tbaa !8
93+
%conv12 = fpext float %12 to double
94+
%13 = load float, float* %valeur, align 4, !tbaa !8
95+
%conv13 = fpext float %13 to double
96+
%call14 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([31 x i8], [31 x i8]* @.str.1, i64 0, i64 0), i32 %11, double %conv12, double %conv13)
97+
%call15 = call i32 @MPI_Finalize() #4
98+
ret i32 0
99+
}
100+
101+
declare dso_local i32 @MPI_Init(i32*, i8***)
102+
103+
declare dso_local i32 @MPI_Comm_rank(%struct.ompi_communicator_t*, i32*)
104+
105+
declare dso_local i32 @MPI_Comm_size(%struct.ompi_communicator_t*, i32*)
106+
107+
declare dso_local void @__enzyme_fwddiff(i8*, ...)
108+
109+
; Function Attrs: nofree nounwind
110+
declare dso_local i32 @printf(i8* nocapture readonly, ...)
111+
112+
declare dso_local i32 @MPI_Finalize()
113+
114+
attributes #0 = { nounwind uwtable }
115+
attributes #4 = { nounwind }
116+
117+
!llvm.module.flags = !{!0}
118+
!llvm.ident = !{!1}
119+
120+
!0 = !{i32 1, !"wchar_size", i32 4}
121+
!1 = !{!"clang version 10.0.1 ([email protected]:llvm/llvm-project ef32c611aa214dea855364efd7ba451ec5ec3f74)"}
122+
!2 = !{!3, !3, i64 0}
123+
!3 = !{!"int", !4, i64 0}
124+
!4 = !{!"omnipotent char", !5, i64 0}
125+
!5 = !{!"Simple C/C++ TBAA"}
126+
!6 = !{!7, !7, i64 0}
127+
!7 = !{!"any pointer", !4, i64 0}
128+
!8 = !{!9, !9, i64 0}
129+
!9 = !{!"float", !4, i64 0}
130+
131+
132+
; CHECK: define internal void @fwddiffemsg1(float* %val1, float* %"val1'", float* %val2, float* %"val2'", i32 %numprocprec, i32 %numprocsuiv, i32 %etiquette)
133+
; CHECK-NEXT: entry:
134+
; CHECK-NEXT: %statut = alloca %struct.ompi_status_public_t, align 8
135+
; CHECK-NEXT: %"'ipc" = bitcast float* %"val1'" to i8*
136+
; CHECK-NEXT: %0 = bitcast float* %val1 to i8*
137+
; CHECK-NEXT: %call = call i32 @MPI_Send(i8* %0, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocsuiv, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*))
138+
; CHECK-NEXT: %1 = call i32 @MPI_Send(i8* %"'ipc", i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocsuiv, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*))
139+
; CHECK-NEXT: %2 = bitcast float* %val2 to i8*
140+
; CHECK-NEXT: %call1 = call i32 @MPI_Recv(i8* %2, i32 1, %struct.ompi_datatype_t* bitcast (%struct.ompi_predefined_datatype_t* @ompi_mpi_real to %struct.ompi_datatype_t*), i32 %numprocprec, i32 %etiquette, %struct.ompi_communicator_t* bitcast (%struct.ompi_predefined_communicator_t* @ompi_mpi_comm_world to %struct.ompi_communicator_t*), %struct.ompi_status_public_t* nonnull %statut)
141+
; CHECK-NEXT: ret void
142+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)