Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 087ffd7

Browse files
committed
add the autodiff batch mode frontend
1 parent aa8f0fd commit 087ffd7

File tree

5 files changed

+236
-127
lines changed

5 files changed

+236
-127
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
7777
/// e.g. in the [JAX
7878
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979
pub mode: DiffMode,
80+
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
81+
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
82+
/// - Calling the function 50 times with a batch size of 2
83+
/// - Calling the function 25 times with a batch size of 4,
84+
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
85+
/// cache locality, better re-usal of primal values, and other optimizations.
86+
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
87+
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
88+
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
89+
/// experiments for now and focus on documenting the implications of a large width.
90+
pub width: u32,
8091
pub ret_activity: DiffActivity,
8192
pub input_activity: Vec<DiffActivity>,
8293
}
@@ -222,13 +233,15 @@ impl AutoDiffAttrs {
222233
pub const fn error() -> Self {
223234
AutoDiffAttrs {
224235
mode: DiffMode::Error,
236+
width: 0,
225237
ret_activity: DiffActivity::None,
226238
input_activity: Vec::new(),
227239
}
228240
}
229241
pub fn source() -> Self {
230242
AutoDiffAttrs {
231243
mode: DiffMode::Source,
244+
width: 0,
232245
ret_activity: DiffActivity::None,
233246
input_activity: Vec::new(),
234247
}

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
7979
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
8080
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
8181
82+
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
8283
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
8384
.label = not applicable here
8485
.label2 = not a `struct`, `enum` or `union`

0 commit comments

Comments
 (0)