@@ -49,9 +49,11 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
49
49
// using iterators and peek()?
50
50
fn match_args_from_caller_to_enzyme < ' ll > (
51
51
cx : & SimpleCx < ' ll > ,
52
+ width : u32 ,
52
53
args : & mut Vec < & ' ll llvm:: Value > ,
53
54
inputs : & [ DiffActivity ] ,
54
55
outer_args : & [ & ' ll llvm:: Value ] ,
56
+ has_sret : bool ,
55
57
) {
56
58
debug ! ( "matching autodiff arguments" ) ;
57
59
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
@@ -63,6 +65,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
63
65
let mut outer_pos: usize = 0 ;
64
66
let mut activity_pos = 0 ;
65
67
68
+ if has_sret {
69
+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
70
+ // inner function will still return something. We increase our outer_pos by one,
71
+ // and once we're done with all other args we will take the return of the inner call and
72
+ // update the sret pointer with it
73
+ outer_pos = 1 ;
74
+ }
75
+
66
76
let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
67
77
let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
68
78
let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -114,21 +124,21 @@ fn match_args_from_caller_to_enzyme<'ll>(
114
124
assert ! ( unsafe {
115
125
llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
116
126
} ) ;
117
- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
118
- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
119
- assert ! ( unsafe {
120
- llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
121
- } ) ;
122
- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
123
- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
124
- assert ! ( unsafe {
125
- llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
126
- } ) ;
127
- args. push ( next_outer_arg2) ;
127
+
128
+ for _ in 0 ..width {
129
+ let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130
+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131
+ assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) } == llvm:: TypeKind :: Pointer ) ;
132
+ let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
133
+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
134
+ assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) } == llvm:: TypeKind :: Integer ) ;
135
+ args. push ( next_outer_arg2) ;
136
+ }
128
137
args. push ( cx. get_metadata_value ( enzyme_const) ) ;
129
138
args. push ( next_outer_arg) ;
130
- outer_pos += 4 ;
139
+ outer_pos += 2 + 2 * width as usize ;
131
140
activity_pos += 2 ;
141
+
132
142
} else {
133
143
// A duplicated pointer will have the following two outer_fn arguments:
134
144
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
@@ -144,6 +154,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
144
154
args. push ( next_outer_arg) ;
145
155
outer_pos += 2 ;
146
156
activity_pos += 1 ;
157
+
158
+ // Now, if width > 1, we need to account for that
159
+ for _ in 1 ..width {
160
+ let next_outer_arg = outer_args[ outer_pos] ;
161
+ args. push ( next_outer_arg) ;
162
+ outer_pos += 1 ;
163
+ }
164
+
147
165
}
148
166
} else {
149
167
// We do not differentiate with respect to this argument.
@@ -324,14 +342,20 @@ fn generate_enzyme_call<'ll>(
324
342
if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
325
343
args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
326
344
}
345
+ if attrs. width > 1 {
346
+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
347
+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
348
+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
349
+ }
327
350
351
+ let has_sret = has_sret ( outer_fn) ;
328
352
let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
329
- match_args_from_caller_to_enzyme ( & cx, & mut args, & attrs. input_activity , & outer_args) ;
353
+ match_args_from_caller_to_enzyme ( & cx, attrs . width , & mut args, & attrs. input_activity , & outer_args, has_sret ) ;
330
354
331
355
let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
332
356
333
357
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
334
- // metadata attachted to it, but we just created this code oota. Given that the
358
+ // metadata attached to it, but we just created this code oota. Given that the
335
359
// differentiated function already has partly confusing metadata, and given that this
336
360
// affects nothing but the autodiff IR, we take a shortcut and just steal metadata from the
337
361
// dummy code which we inserted at a higher level.
@@ -352,8 +376,6 @@ fn generate_enzyme_call<'ll>(
352
376
// Now that we copied the metadata, get rid of dummy code.
353
377
llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
354
378
355
- let has_sret = has_sret ( outer_fn) ;
356
-
357
379
if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
358
380
if has_sret {
359
381
// This is what we already have in our outer_fn (shortened):
0 commit comments