1
1
use std:: ptr;
2
+
2
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
3
4
use rustc_codegen_ssa:: ModuleCodegen ;
4
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
@@ -14,7 +15,6 @@ use crate::errors::{AutoDiffWithoutEnable, LlvmError};
14
15
use crate :: llvm:: AttributePlace :: Function ;
15
16
use crate :: llvm:: { Metadata , True } ;
16
17
use crate :: value:: Value ;
17
-
18
18
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
19
19
20
20
fn get_params ( fnc : & Value ) -> Vec < & Value > {
@@ -28,14 +28,14 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
28
28
}
29
29
}
30
30
31
- fn has_sret ( fnc : & Value ) -> bool {
32
- let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33
- if num_args == 0 {
34
- false
35
- } else {
36
- unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
37
- }
31
+ fn has_sret ( fnc : & Value ) -> bool {
32
+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33
+ if num_args == 0 {
34
+ false
35
+ } else {
36
+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
38
37
}
38
+ }
39
39
40
40
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41
41
// original inputs, as well as metadata and the additional shadow arguments.
@@ -128,17 +128,22 @@ fn match_args_from_caller_to_enzyme<'ll>(
128
128
for _ in 0 ..width {
129
129
let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130
130
let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131
- assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) } == llvm:: TypeKind :: Pointer ) ;
131
+ assert ! (
132
+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) }
133
+ == llvm:: TypeKind :: Pointer
134
+ ) ;
132
135
let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
133
136
let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
134
- assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) } == llvm:: TypeKind :: Integer ) ;
137
+ assert ! (
138
+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) }
139
+ == llvm:: TypeKind :: Integer
140
+ ) ;
135
141
args. push ( next_outer_arg2) ;
136
142
}
137
143
args. push ( cx. get_metadata_value ( enzyme_const) ) ;
138
144
args. push ( next_outer_arg) ;
139
145
outer_pos += 2 + 2 * width as usize ;
140
146
activity_pos += 2 ;
141
-
142
147
} else {
143
148
// A duplicated pointer will have the following two outer_fn arguments:
144
149
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
@@ -161,7 +166,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
161
166
args. push ( next_outer_arg) ;
162
167
outer_pos += 1 ;
163
168
}
164
-
165
169
}
166
170
} else {
167
171
// We do not differentiate with resprect to this argument.
@@ -172,7 +176,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
172
176
}
173
177
}
174
178
175
-
176
179
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
177
180
// arguments. We do however need to declare them with their correct return type.
178
181
// We already figured the correct return type out in our frontend, when generating the outer_fn,
@@ -350,7 +353,14 @@ fn generate_enzyme_call<'ll>(
350
353
351
354
let has_sret = has_sret ( outer_fn) ;
352
355
let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
353
- match_args_from_caller_to_enzyme ( & cx, attrs. width , & mut args, & attrs. input_activity , & outer_args, has_sret) ;
356
+ match_args_from_caller_to_enzyme (
357
+ & cx,
358
+ attrs. width ,
359
+ & mut args,
360
+ & attrs. input_activity ,
361
+ & outer_args,
362
+ has_sret,
363
+ ) ;
354
364
355
365
let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
356
366
0 commit comments