@@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
77
77
/// e.g. in the [JAX
78
78
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
79
79
pub mode : DiffMode ,
80
+ /// A user-provided, batching width. If not given, we will default to 1 (no batching).
81
+ /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
82
+ /// - Calling the function 50 times with a batch size of 2
83
+ /// - Calling the function 25 times with a batch size of 4,
84
+ /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
85
+ /// cache locality, better re-usal of primal values, and other optimizations.
86
+ /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
87
+ /// times, so this massively increases code size. As such, values like 1024 are unlikely to
88
+ /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
89
+ /// experiments for now and focus on documenting the implications of a large width.
90
+ pub width : u32 ,
80
91
pub ret_activity : DiffActivity ,
81
92
pub input_activity : Vec < DiffActivity > ,
82
93
}
@@ -222,13 +233,15 @@ impl AutoDiffAttrs {
222
233
pub const fn error ( ) -> Self {
223
234
AutoDiffAttrs {
224
235
mode : DiffMode :: Error ,
236
+ width : 0 ,
225
237
ret_activity : DiffActivity :: None ,
226
238
input_activity : Vec :: new ( ) ,
227
239
}
228
240
}
229
241
pub fn source ( ) -> Self {
230
242
AutoDiffAttrs {
231
243
mode : DiffMode :: Source ,
244
+ width : 0 ,
232
245
ret_activity : DiffActivity :: None ,
233
246
input_activity : Vec :: new ( ) ,
234
247
}
0 commit comments