Skip to content

Commit 2a49c14

Browse files
committed
wip: So yeah, tests pass, but still eval_always (not far from disk caching though I think)
1 parent 9b1a7a9 commit 2a49c14

File tree

13 files changed

+288
-33
lines changed

13 files changed

+288
-33
lines changed

compiler/rustc_ast/src/tokenstream.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
//! ownership of the original.
1515
1616
use std::borrow::Cow;
17+
use std::hash::Hash;
1718
use std::{cmp, fmt, iter};
1819

1920
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
2021
use rustc_data_structures::sync::{self, Lrc};
2122
use rustc_macros::{Decodable, Encodable, HashStable_Generic};
22-
use rustc_serialize::{Decodable, Encodable};
23+
use rustc_serialize::{Decodable, Encodable, Encoder};
24+
use rustc_span::def_id::{CrateNum, DefIndex};
2325
use rustc_span::{sym, Span, SpanDecoder, SpanEncoder, Symbol, DUMMY_SP};
2426

2527
use crate::ast::{AttrStyle, StmtKind};
@@ -139,8 +141,10 @@ impl fmt::Debug for LazyAttrTokenStream {
139141
}
140142

141143
impl<S: SpanEncoder> Encodable<S> for LazyAttrTokenStream {
142-
fn encode(&self, _s: &mut S) {
143-
panic!("Attempted to encode LazyAttrTokenStream");
144+
fn encode(&self, s: &mut S) {
145+
// TODO: welp
146+
// TODO: (also) `.flattened()` here?
147+
self.to_attr_token_stream().encode(s)
144148
}
145149
}
146150

@@ -296,6 +300,96 @@ pub struct AttrsTarget {
296300
#[derive(Clone, Debug, Default, Encodable, Decodable)]
297301
pub struct TokenStream(pub(crate) Lrc<Vec<TokenTree>>);
298302

303+
struct HashEncoder<H: std::hash::Hasher> {
304+
hasher: H,
305+
}
306+
307+
impl<H: std::hash::Hasher> Encoder for HashEncoder<H> {
308+
fn emit_usize(&mut self, v: usize) {
309+
self.hasher.write_usize(v)
310+
}
311+
312+
fn emit_u128(&mut self, v: u128) {
313+
self.hasher.write_u128(v)
314+
}
315+
316+
fn emit_u64(&mut self, v: u64) {
317+
self.hasher.write_u64(v)
318+
}
319+
320+
fn emit_u32(&mut self, v: u32) {
321+
self.hasher.write_u32(v)
322+
}
323+
324+
fn emit_u16(&mut self, v: u16) {
325+
self.hasher.write_u16(v)
326+
}
327+
328+
fn emit_u8(&mut self, v: u8) {
329+
self.hasher.write_u8(v)
330+
}
331+
332+
fn emit_isize(&mut self, v: isize) {
333+
self.hasher.write_isize(v)
334+
}
335+
336+
fn emit_i128(&mut self, v: i128) {
337+
self.hasher.write_i128(v)
338+
}
339+
340+
fn emit_i64(&mut self, v: i64) {
341+
self.hasher.write_i64(v)
342+
}
343+
344+
fn emit_i32(&mut self, v: i32) {
345+
self.hasher.write_i32(v)
346+
}
347+
348+
fn emit_i16(&mut self, v: i16) {
349+
self.hasher.write_i16(v)
350+
}
351+
352+
fn emit_raw_bytes(&mut self, s: &[u8]) {
353+
self.hasher.write(s)
354+
}
355+
}
356+
357+
impl<H: std::hash::Hasher> SpanEncoder for HashEncoder<H> {
358+
fn encode_span(&mut self, span: Span) {
359+
span.hash(&mut self.hasher)
360+
}
361+
362+
fn encode_symbol(&mut self, symbol: Symbol) {
363+
symbol.hash(&mut self.hasher)
364+
}
365+
366+
fn encode_expn_id(&mut self, expn_id: rustc_span::ExpnId) {
367+
expn_id.hash(&mut self.hasher)
368+
}
369+
370+
fn encode_syntax_context(&mut self, syntax_context: rustc_span::SyntaxContext) {
371+
syntax_context.hash(&mut self.hasher)
372+
}
373+
374+
fn encode_crate_num(&mut self, crate_num: CrateNum) {
375+
crate_num.hash(&mut self.hasher)
376+
}
377+
378+
fn encode_def_index(&mut self, def_index: DefIndex) {
379+
def_index.hash(&mut self.hasher)
380+
}
381+
382+
fn encode_def_id(&mut self, def_id: rustc_span::def_id::DefId) {
383+
def_id.hash(&mut self.hasher)
384+
}
385+
}
386+
387+
impl Hash for TokenStream {
388+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
389+
Encodable::encode(self, &mut HashEncoder { hasher: state });
390+
}
391+
}
392+
299393
/// Indicates whether a token can join with the following token to form a
300394
/// compound token. Used for conversions to `proc_macro::Spacing`. Also used to
301395
/// guide pretty-printing, which is where the `JointHidden` value (which isn't

compiler/rustc_expand/src/base.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,9 @@ pub trait ResolverExpand {
10721072
trait_def_id: DefId,
10731073
impl_def_id: LocalDefId,
10741074
) -> Result<Vec<(Ident, Option<Ident>)>, Indeterminate>;
1075+
1076+
fn register_proc_macro_invoc(&mut self, invoc_id: LocalExpnId, ext: Lrc<SyntaxExtension>);
1077+
fn unregister_proc_macro_invoc(&mut self, invoc_id: LocalExpnId);
10751078
}
10761079

10771080
pub trait LintStoreExpand {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// TODO: remove
2+
#![allow(dead_code)]
3+
4+
use std::cell::Cell;
5+
use std::ptr;
6+
7+
use rustc_ast::tokenstream::TokenStream;
8+
use rustc_middle::ty::TyCtxt;
9+
use rustc_span::profiling::SpannedEventArgRecorder;
10+
use rustc_span::LocalExpnId;
11+
12+
use crate::base::ExtCtxt;
13+
use crate::errors;
14+
15+
pub(super) fn expand<'tcx>(
16+
tcx: TyCtxt<'tcx>,
17+
key: (LocalExpnId, &'tcx TokenStream),
18+
) -> Result<&'tcx TokenStream, ()> {
19+
let (invoc_id, input) = key;
20+
21+
let res = with_context(|(ecx, client)| {
22+
let span = invoc_id.expn_data().call_site;
23+
let _timer =
24+
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
25+
recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
26+
});
27+
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
28+
let strategy = crate::proc_macro::exec_strategy(ecx);
29+
let server = crate::proc_macro_server::Rustc::new(ecx);
30+
let res = match client.run(&strategy, server, input.clone(), proc_macro_backtrace) {
31+
// TODO: without flattened some (weird) tests fail, but no idea if it's correct/enough
32+
Ok(stream) => Ok(tcx.arena.alloc(stream.flattened()) as &TokenStream),
33+
Err(e) => {
34+
ecx.dcx().emit_err({
35+
errors::ProcMacroDerivePanicked {
36+
span,
37+
message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
38+
message: message.into(),
39+
}),
40+
}
41+
});
42+
Err(())
43+
}
44+
};
45+
res
46+
});
47+
48+
res
49+
}
50+
51+
type CLIENT = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
52+
53+
// based on rust/compiler/rustc_middle/src/ty/context/tls.rs
54+
// #[cfg(not(parallel_compiler))]
55+
thread_local! {
56+
/// A thread local variable that stores a pointer to the current `CONTEXT`.
57+
static TLV: Cell<(*mut (), Option<CLIENT>)> = const { Cell::new((ptr::null_mut(), None)) };
58+
}
59+
60+
#[inline]
61+
fn erase(context: &mut ExtCtxt<'_>) -> *mut () {
62+
context as *mut _ as *mut ()
63+
}
64+
65+
#[inline]
66+
unsafe fn downcast<'a>(context: *mut ()) -> &'a mut ExtCtxt<'a> {
67+
unsafe { &mut *(context as *mut ExtCtxt<'a>) }
68+
}
69+
70+
/// Sets `context` as the new current `CONTEXT` for the duration of the function `f`.
71+
#[inline]
72+
pub fn enter_context<'a, F, R>(context: (&mut ExtCtxt<'a>, CLIENT), f: F) -> R
73+
where
74+
F: FnOnce() -> R,
75+
{
76+
let (ectx, client) = context;
77+
let erased = (erase(ectx), Some(client));
78+
TLV.with(|tlv| {
79+
let old = tlv.replace(erased);
80+
let _reset = rustc_data_structures::defer(move || tlv.set(old));
81+
f()
82+
})
83+
}
84+
85+
/// Allows access to the current `CONTEXT` in a closure if one is available.
86+
#[inline]
87+
#[track_caller]
88+
pub fn with_context_opt<F, R>(f: F) -> R
89+
where
90+
F: for<'a, 'b> FnOnce(Option<&'b mut (&mut ExtCtxt<'a>, CLIENT)>) -> R,
91+
{
92+
let (ectx, client_opt) = TLV.get();
93+
if ectx.is_null() {
94+
f(None)
95+
} else {
96+
// We could get an `CONTEXT` pointer from another thread.
97+
// Ensure that `CONTEXT` is `DynSync`.
98+
// TODO: we should not be able to?
99+
// sync::assert_dyn_sync::<CONTEXT<'_>>();
100+
101+
unsafe { f(Some(&mut (downcast(ectx), client_opt.unwrap()))) }
102+
}
103+
}
104+
105+
/// Allows access to the current `CONTEXT`.
106+
/// Panics if there is no `CONTEXT` available.
107+
#[inline]
108+
pub fn with_context<F, R>(f: F) -> R
109+
where
110+
F: for<'a, 'b> FnOnce(&'b mut (&mut ExtCtxt<'a>, CLIENT)) -> R,
111+
{
112+
with_context_opt(|opt_context| f(opt_context.expect("no CONTEXT stored in tls")))
113+
}

compiler/rustc_expand/src/expand.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,9 +794,14 @@ impl<'a, 'b> MacroExpander<'a, 'b> {
794794
span,
795795
path,
796796
};
797+
self.cx
798+
.resolver
799+
.register_proc_macro_invoc(invoc.expansion_data.id, ext.clone());
800+
invoc.expansion_data.id.expn_data();
797801
let items = match expander.expand(self.cx, span, &meta, item, is_const) {
798802
ExpandResult::Ready(items) => items,
799803
ExpandResult::Retry(item) => {
804+
self.cx.resolver.unregister_proc_macro_invoc(invoc.expansion_data.id);
800805
// Reassemble the original invocation for retrying.
801806
return ExpandResult::Retry(Invocation {
802807
kind: InvocationKind::Derive { path: meta.path, item, is_const },

compiler/rustc_expand/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ mod proc_macro_server;
2828
pub use mbe::macro_rules::compile_declarative_macro;
2929
pub mod base;
3030
pub mod config;
31+
pub(crate) mod derive_macro_expansion;
3132
pub mod expand;
3233
pub mod module;
3334
// FIXME(Nilstrieb) Translate proc_macro diagnostics
3435
#[allow(rustc::untranslatable_diagnostic)]
3536
pub mod proc_macro;
3637

38+
pub fn provide(providers: &mut rustc_middle::util::Providers) {
39+
providers.derive_macro_expansion = derive_macro_expansion::expand;
40+
}
41+
3742
rustc_fluent_macro::fluent_messages! { "../messages.ftl" }

compiler/rustc_expand/src/proc_macro.rs

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_ast as ast;
22
use rustc_ast::ptr::P;
33
use rustc_ast::tokenstream::TokenStream;
44
use rustc_errors::ErrorGuaranteed;
5+
use rustc_middle::ty;
56
use rustc_parse::parser::{ForceCollect, Parser};
67
use rustc_session::config::ProcMacroExecutionStrategy;
78
use rustc_span::profiling::SpannedEventArgRecorder;
@@ -31,7 +32,7 @@ impl<T> pm::bridge::server::MessagePipe<T> for MessagePipe<T> {
3132
}
3233
}
3334

34-
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
35+
pub fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
3536
pm::bridge::server::MaybeCrossThread::<MessagePipe<_>>::new(
3637
ecx.sess.opts.unstable_opts.proc_macro_execution_strategy
3738
== ProcMacroExecutionStrategy::CrossThread,
@@ -124,36 +125,27 @@ impl MultiItemModifier for DeriveProcMacro {
124125
// altogether. See #73345.
125126
crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess);
126127
let input = item.to_tokens();
127-
let stream = {
128-
let _timer =
129-
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
130-
recorder.record_arg_with_span(
131-
ecx.sess.source_map(),
132-
ecx.expansion_descr(),
133-
span,
134-
);
135-
});
136-
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
137-
let strategy = exec_strategy(ecx);
138-
let server = proc_macro_server::Rustc::new(ecx);
139-
match self.client.run(&strategy, server, input, proc_macro_backtrace) {
140-
Ok(stream) => stream,
141-
Err(e) => {
142-
ecx.dcx().emit_err({
143-
errors::ProcMacroDerivePanicked {
144-
span,
145-
message: e.as_str().map(|message| {
146-
errors::ProcMacroDerivePanickedHelp { message: message.into() }
147-
}),
148-
}
149-
});
150-
return ExpandResult::Ready(vec![]);
151-
}
152-
}
128+
let res = ty::tls::with(|tcx| {
129+
// TODO: without flattened some (weird) tests fail, but no idea if it's correct/enough
130+
let input = tcx.arena.alloc(input.flattened()) as &TokenStream;
131+
let invoc_id = ecx.current_expansion.id;
132+
133+
assert_eq!(invoc_id.expn_data().call_site, span);
134+
135+
let res = crate::derive_macro_expansion::enter_context((ecx, self.client), move || {
136+
let res = tcx.derive_macro_expansion((invoc_id, input)).cloned();
137+
res
138+
});
139+
140+
res
141+
});
142+
let Ok(output) = res else {
143+
// error will already have been emitted
144+
return ExpandResult::Ready(vec![]);
153145
};
154146

155147
let error_count_before = ecx.dcx().err_count();
156-
let mut parser = Parser::new(&ecx.sess.psess, stream, Some("proc-macro derive"));
148+
let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
157149
let mut items = vec![];
158150

159151
loop {

compiler/rustc_interface/src/passes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,7 @@ pub static DEFAULT_QUERY_PROVIDERS: LazyLock<Providers> = LazyLock::new(|| {
624624
providers.resolutions = |tcx, ()| tcx.resolver_for_lowering_raw(()).1;
625625
providers.early_lint_checks = early_lint_checks;
626626
proc_macro_decls::provide(providers);
627+
rustc_expand::provide(providers);
627628
rustc_const_eval::provide(providers);
628629
rustc_middle::hir::provide(providers);
629630
rustc_borrowck::provide(providers);

compiler/rustc_middle/src/arena.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ macro_rules! arena_types {
119119
[decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph,
120120
[] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls,
121121
[] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>,
122+
[] token_stream: rustc_ast::tokenstream::TokenStream,
122123
]);
123124
)
124125
}

compiler/rustc_middle/src/query/erase.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::intrinsics::transmute_unchecked;
22
use std::mem::MaybeUninit;
33

4+
use rustc_ast::tokenstream::TokenStream;
5+
46
use crate::query::CyclePlaceholder;
57
use crate::ty::adjustment::CoerceUnsizedInfo;
68
use crate::ty::{self, Ty};
@@ -172,6 +174,10 @@ impl EraseType for Result<ty::EarlyBinder<'_, Ty<'_>>, CyclePlaceholder> {
172174
type Result = [u8; size_of::<Result<ty::EarlyBinder<'static, Ty<'_>>, CyclePlaceholder>>()];
173175
}
174176

177+
impl EraseType for Result<&'_ TokenStream, ()> {
178+
type Result = [u8; size_of::<Result<&'static TokenStream, ()>>()];
179+
}
180+
175181
impl<T> EraseType for Option<&'_ T> {
176182
type Result = [u8; size_of::<Option<&'static ()>>()];
177183
}

0 commit comments

Comments
 (0)