Skip to content

Commit 51e3556

Browse files
committed
Initial naive implementation using Symbols to represent autodiff modes (Forward, Reverse)
Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work). This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc). Some tests are currently failing. I'll address them next.
1 parent 3805573 commit 51e3556

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,29 +259,41 @@ mod llvm_enzyme {
259259
// create TokenStream from vec elemtents:
260260
// meta_item doesn't have a .tokens field
261261
let mut ts: Vec<TokenTree> = vec![];
262-
if meta_item_vec.len() < 2 {
263-
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
264-
// input and output args.
262+
if meta_item_vec.len() < 1 {
263+
// At the bare minimum, we need a fnc name.
265264
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
266265
return vec![item];
267266
}
268267

269-
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
268+
let mode_symbol = match mode {
269+
DiffMode::Forward => sym::Forward,
270+
DiffMode::Reverse => sym::Reverse,
271+
_ => unreachable!("Unsupported mode: {:?}", mode),
272+
};
273+
274+
// Insert mode token
275+
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
276+
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
277+
ts.insert(
278+
1,
279+
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
280+
);
270281

271282
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
272283
// If it is not given, we default to 1 (scalar mode).
273284
let start_position;
274285
let kind: LitKind = LitKind::Integer;
275286
let symbol;
276-
if meta_item_vec.len() >= 3
277-
&& let Some(width) = width(&meta_item_vec[2])
287+
if meta_item_vec.len() >= 2
288+
&& let Some(width) = width(&meta_item_vec[1])
278289
{
279-
start_position = 3;
290+
start_position = 2;
280291
symbol = Symbol::intern(&width.to_string());
281292
} else {
282-
start_position = 2;
293+
start_position = 1;
283294
symbol = sym::integer(1);
284295
}
296+
285297
let l: Lit = Lit { kind, symbol, suffix: None };
286298
let t = Token::new(TokenKind::Literal(l), Span::default());
287299
let comma = Token::new(TokenKind::Comma, Span::default());

compiler/rustc_span/src/symbol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ symbols! {
253253
FnMut,
254254
FnOnce,
255255
Formatter,
256+
Forward,
256257
From,
257258
FromIterator,
258259
FromResidual,
@@ -348,6 +349,7 @@ symbols! {
348349
Result,
349350
ResumeTy,
350351
Return,
352+
Reverse,
351353
Right,
352354
Rust,
353355
RustaceansAreAwesome,

0 commit comments

Comments
 (0)