Skip to content

Commit 2a243a6

Browse files
committed
fmt
1 parent 963b3ce commit 2a243a6

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::ptr;
2+
23
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
34
use rustc_codegen_ssa::ModuleCodegen;
45
use rustc_codegen_ssa::back::write::ModuleConfig;
@@ -14,7 +15,6 @@ use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1415
use crate::llvm::AttributePlace::Function;
1516
use crate::llvm::{Metadata, True};
1617
use crate::value::Value;
17-
1818
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
1919

2020
fn get_params(fnc: &Value) -> Vec<&Value> {
@@ -28,14 +28,14 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828
}
2929
}
3030

31-
fn has_sret(fnc: &Value) -> bool {
32-
let num_args = unsafe { llvm::LLVMCountParams(fnc) as usize };
33-
if num_args == 0 {
34-
false
35-
} else {
36-
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
37-
}
31+
fn has_sret(fnc: &Value) -> bool {
32+
let num_args = unsafe { llvm::LLVMCountParams(fnc) as usize };
33+
if num_args == 0 {
34+
false
35+
} else {
36+
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
3837
}
38+
}
3939

4040
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
4141
// original inputs, as well as metadata and the additional shadow arguments.
@@ -128,17 +128,22 @@ fn match_args_from_caller_to_enzyme<'ll>(
128128
for _ in 0..width {
129129
let next_outer_arg2 = outer_args[outer_pos + 2];
130130
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131-
assert!(unsafe {llvm::LLVMRustGetTypeKind(next_outer_ty2)} == llvm::TypeKind::Pointer);
131+
assert!(
132+
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty2) }
133+
== llvm::TypeKind::Pointer
134+
);
132135
let next_outer_arg3 = outer_args[outer_pos + 3];
133136
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
134-
assert!(unsafe{ llvm::LLVMRustGetTypeKind(next_outer_ty3)} == llvm::TypeKind::Integer);
137+
assert!(
138+
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty3) }
139+
== llvm::TypeKind::Integer
140+
);
135141
args.push(next_outer_arg2);
136142
}
137143
args.push(cx.get_metadata_value(enzyme_const));
138144
args.push(next_outer_arg);
139145
outer_pos += 2 + 2 * width as usize;
140146
activity_pos += 2;
141-
142147
} else {
143148
// A duplicated pointer will have the following two outer_fn arguments:
144149
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
@@ -161,7 +166,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
161166
args.push(next_outer_arg);
162167
outer_pos += 1;
163168
}
164-
165169
}
166170
} else {
167171
// We do not differentiate with resprect to this argument.
@@ -172,7 +176,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
172176
}
173177
}
174178

175-
176179
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
177180
// arguments. We do however need to declare them with their correct return type.
178181
// We already figured the correct return type out in our frontend, when generating the outer_fn,
@@ -350,7 +353,14 @@ fn generate_enzyme_call<'ll>(
350353

351354
let has_sret = has_sret(outer_fn);
352355
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
353-
match_args_from_caller_to_enzyme(&cx, attrs.width, &mut args, &attrs.input_activity, &outer_args, has_sret);
356+
match_args_from_caller_to_enzyme(
357+
&cx,
358+
attrs.width,
359+
&mut args,
360+
&attrs.input_activity,
361+
&outer_args,
362+
has_sret,
363+
);
354364

355365
let call = builder.call(enzyme_ty, ad_fn, &args, None);
356366

0 commit comments

Comments
 (0)