Skip to content

Commit 95fb634

Browse files
committed
lower batch width to our enzyme backend
1 parent 80668bb commit 95fb634

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
4949
// using iterators and peek()?
5050
fn match_args_from_caller_to_enzyme<'ll>(
5151
cx: &SimpleCx<'ll>,
52+
width: u32,
5253
args: &mut Vec<&'ll llvm::Value>,
5354
inputs: &[DiffActivity],
5455
outer_args: &[&'ll llvm::Value],
56+
has_sret: bool,
5557
) {
5658
debug!("matching autodiff arguments");
5759
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -63,6 +65,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
6365
let mut outer_pos: usize = 0;
6466
let mut activity_pos = 0;
6567

68+
if has_sret {
69+
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
70+
// inner function will still return something. We increase our outer_pos by one,
71+
// and once we're done with all other args we will take the return of the inner call and
72+
// update the sret pointer with it
73+
outer_pos = 1;
74+
}
75+
6676
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
6777
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
6878
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
@@ -114,21 +124,21 @@ fn match_args_from_caller_to_enzyme<'ll>(
114124
assert!(unsafe {
115125
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
116126
});
117-
let next_outer_arg2 = outer_args[outer_pos + 2];
118-
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
119-
assert!(unsafe {
120-
llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
121-
});
122-
let next_outer_arg3 = outer_args[outer_pos + 3];
123-
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
124-
assert!(unsafe {
125-
llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
126-
});
127-
args.push(next_outer_arg2);
127+
128+
for _ in 0..width {
129+
let next_outer_arg2 = outer_args[outer_pos + 2];
130+
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131+
assert!(unsafe {llvm::LLVMRustGetTypeKind(next_outer_ty2)} == llvm::TypeKind::Pointer);
132+
let next_outer_arg3 = outer_args[outer_pos + 3];
133+
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
134+
assert!(unsafe{ llvm::LLVMRustGetTypeKind(next_outer_ty3)} == llvm::TypeKind::Integer);
135+
args.push(next_outer_arg2);
136+
}
128137
args.push(cx.get_metadata_value(enzyme_const));
129138
args.push(next_outer_arg);
130-
outer_pos += 4;
139+
outer_pos += 2 + 2 * width as usize;
131140
activity_pos += 2;
141+
132142
} else {
133143
// A duplicated pointer will have the following two outer_fn arguments:
134144
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
@@ -144,6 +154,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
144154
args.push(next_outer_arg);
145155
outer_pos += 2;
146156
activity_pos += 1;
157+
158+
// Now, if width > 1, we need to account for that
159+
for _ in 1..width {
160+
let next_outer_arg = outer_args[outer_pos];
161+
args.push(next_outer_arg);
162+
outer_pos += 1;
163+
}
164+
147165
}
148166
} else {
149167
// We do not differentiate with resprect to this argument.
@@ -324,14 +342,20 @@ fn generate_enzyme_call<'ll>(
324342
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
325343
args.push(cx.get_metadata_value(enzyme_primal_ret));
326344
}
345+
if attrs.width > 1 {
346+
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
347+
args.push(cx.get_metadata_value(enzyme_width));
348+
args.push(cx.get_const_i64(attrs.width as u64));
349+
}
327350

351+
let has_sret = has_sret(outer_fn);
328352
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
329-
match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
353+
match_args_from_caller_to_enzyme(&cx, attrs.width, &mut args, &attrs.input_activity, &outer_args, has_sret);
330354

331355
let call = builder.call(enzyme_ty, ad_fn, &args, None);
332356

333357
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
334-
// metadata attachted to it, but we just created this code oota. Given that the
358+
// metadata attached to it, but we just created this code oota. Given that the
335359
// differentiated function already has partly confusing metadata, and given that this
336360
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
337361
// dummy code which we inserted at a higher level.
@@ -352,8 +376,6 @@ fn generate_enzyme_call<'ll>(
352376
// Now that we copied the metadata, get rid of dummy code.
353377
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
354378

355-
let has_sret = has_sret(outer_fn);
356-
357379
if cx.val_ty(call) == cx.type_void() || has_sret {
358380
if has_sret {
359381
// This is what we already have in our outer_fn (shortened):

0 commit comments

Comments
 (0)