Skip to content

Commit 2bf8bfa

Browse files
authored
Merge pull request #2 from rust-lang/variant
Create variant macro
2 parents b0a9640 + ce9f579 commit 2bf8bfa

File tree

4 files changed

+265
-5
lines changed

4 files changed

+265
-5
lines changed

README.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
# `impl_trait_utils`
1+
# impl-trait-utils
22

33
Utilities for working with impl traits in Rust.
44

5-
## `trait_transformer`
5+
## `make_variant`
66

7-
Trait transformer is an experimental crate that generates specialized versions of a base trait. For example, if you want a `Send`able version of your trait, you'd write:
7+
`make_variant` generates a specialized version of a base trait that uses `async fn` and/or `-> impl Trait`. For example, if you want a `Send`able version of your trait, you'd write:
88

99
```rust
10-
#[trait_transformer(SendIntFactory: Send)]
10+
#[trait_transformer::make_variant(SendIntFactory: Send)]
1111
trait IntFactory {
1212
async fn make(&self) -> i32;
1313
// ..or..
@@ -16,7 +16,13 @@ trait IntFactory {
1616
}
1717
```
1818

19-
Which creates a new `SendIntFactory: IntFactory + Send` trait and additionally bounds `SendIntFactory::make(): Send` and `SendIntFactory::stream(): Send`. The generated sytax is still experimental, as it relies on the nightly and unstable `async_fn_in_trait`, `return_position_impl_trait_in_trait`, and `return_type_notation` features.
19+
Which creates a new `SendIntFactory: IntFactory + Send` trait and additionally bounds `SendIntFactory::make(): Send` and `SendIntFactory::stream(): Send`. Ordinary methods are not affected.
20+
21+
Implementers of the trait can choose to implement the variant instead of the original trait. The macro creates a blanket impl which ensures that any type which implements the variant also implements the original trait.
22+
23+
## `trait_transformer`
24+
25+
`trait_transformer` does the same thing as `make_variant`, but using experimental nightly-only syntax that depends on the `return_type_notation` feature. It may be used to experiment with new kinds of trait transformations in the future.
2026

2127
#### License and usage notes
2228

trait_transformer/examples/variant.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use std::future::Future;
10+
11+
use trait_transformer::make_variant;
12+
13+
#[make_variant(SendIntFactory: Send)]
14+
trait IntFactory {
15+
const NAME: &'static str;
16+
17+
type MyFut<'a>: Future
18+
where
19+
Self: 'a;
20+
21+
async fn make(&self, x: u32, y: &str) -> i32;
22+
fn stream(&self) -> impl Iterator<Item = i32>;
23+
fn call(&self) -> u32;
24+
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
25+
}
26+
27+
fn main() {}

trait_transformer/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
#![doc = include_str!("../../README.md")]
10+
911
mod transformer;
12+
mod variant;
1013

1114
#[proc_macro_attribute]
1215
pub fn trait_transformer(
@@ -15,3 +18,11 @@ pub fn trait_transformer(
1518
) -> proc_macro::TokenStream {
1619
transformer::trait_transformer(attr, item)
1720
}
21+
22+
#[proc_macro_attribute]
23+
pub fn make_variant(
24+
attr: proc_macro::TokenStream,
25+
item: proc_macro::TokenStream,
26+
) -> proc_macro::TokenStream {
27+
variant::make_variant(attr, item)
28+
}

trait_transformer/src/variant.rs

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
// Copyright (c) 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use std::iter;
10+
11+
use proc_macro2::TokenStream;
12+
use quote::quote;
13+
use syn::{
14+
parse::{Parse, ParseStream},
15+
parse_macro_input,
16+
punctuated::Punctuated,
17+
token::Plus,
18+
Error, FnArg, Generics, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature, Token,
19+
TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeImplTrait,
20+
TypeParamBound,
21+
};
22+
23+
struct Attrs {
24+
variant: MakeVariant,
25+
}
26+
27+
impl Parse for Attrs {
28+
fn parse(input: ParseStream) -> Result<Self> {
29+
Ok(Self {
30+
variant: MakeVariant::parse(input)?,
31+
})
32+
}
33+
}
34+
35+
struct MakeVariant {
36+
name: Ident,
37+
#[allow(unused)]
38+
colon: Token![:],
39+
bounds: Punctuated<TraitBound, Plus>,
40+
}
41+
42+
impl Parse for MakeVariant {
43+
fn parse(input: ParseStream) -> Result<Self> {
44+
Ok(Self {
45+
name: input.parse()?,
46+
colon: input.parse()?,
47+
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
48+
})
49+
}
50+
}
51+
52+
pub fn make_variant(
53+
attr: proc_macro::TokenStream,
54+
item: proc_macro::TokenStream,
55+
) -> proc_macro::TokenStream {
56+
let attrs = parse_macro_input!(attr as Attrs);
57+
let item = parse_macro_input!(item as ItemTrait);
58+
59+
let variant = mk_variant(&attrs, &item);
60+
let blanket_impl = mk_blanket_impl(&attrs, &item);
61+
let output = quote! {
62+
#item
63+
#variant
64+
#blanket_impl
65+
};
66+
67+
output.into()
68+
}
69+
70+
fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
71+
let MakeVariant {
72+
ref name,
73+
colon: _,
74+
ref bounds,
75+
} = attrs.variant;
76+
let bounds: Vec<_> = bounds
77+
.into_iter()
78+
.map(|b| TypeParamBound::Trait(b.clone()))
79+
.collect();
80+
let variant = ItemTrait {
81+
ident: name.clone(),
82+
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
83+
items: tr
84+
.items
85+
.iter()
86+
.map(|item| transform_item(item, &bounds))
87+
.collect(),
88+
..tr.clone()
89+
};
90+
quote! { #variant }
91+
}
92+
93+
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
94+
// #[make_variant(SendIntFactory: Send)]
95+
// trait IntFactory {
96+
// async fn make(&self, x: u32, y: &str) -> i32;
97+
// fn stream(&self) -> impl Iterator<Item = i32>;
98+
// fn call(&self) -> u32;
99+
// }
100+
//
101+
// becomes:
102+
//
103+
// trait SendIntFactory: Send {
104+
// fn make(&self, x: u32, y: &str) -> impl ::core::future::Future<Output = i32> + Send;
105+
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
106+
// fn call(&self) -> u32;
107+
// }
108+
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
109+
return item.clone();
110+
};
111+
let (arrow, output) = if sig.asyncness.is_some() {
112+
let orig = match &sig.output {
113+
ReturnType::Default => quote! { () },
114+
ReturnType::Type(_, ty) => quote! { #ty },
115+
};
116+
let future = syn::parse2(quote! { ::core::future::Future<Output = #orig> }).unwrap();
117+
let ty = Type::ImplTrait(TypeImplTrait {
118+
impl_token: syn::parse2(quote! { impl }).unwrap(),
119+
bounds: iter::once(TypeParamBound::Trait(future))
120+
.chain(bounds.iter().cloned())
121+
.collect(),
122+
});
123+
(syn::parse2(quote! { -> }).unwrap(), ty)
124+
} else {
125+
match &sig.output {
126+
ReturnType::Type(arrow, ty) => match &**ty {
127+
Type::ImplTrait(it) => {
128+
let ty = Type::ImplTrait(TypeImplTrait {
129+
impl_token: it.impl_token,
130+
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
131+
});
132+
(*arrow, ty)
133+
}
134+
_ => return item.clone(),
135+
},
136+
ReturnType::Default => return item.clone(),
137+
}
138+
};
139+
TraitItem::Fn(TraitItemFn {
140+
sig: Signature {
141+
asyncness: None,
142+
output: ReturnType::Type(arrow, Box::new(output)),
143+
..sig.clone()
144+
},
145+
..fn_item.clone()
146+
})
147+
}
148+
149+
fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
150+
let orig = &tr.ident;
151+
let variant = &attrs.variant.name;
152+
let items = tr.items.iter().map(|item| blanket_impl_item(item, variant));
153+
quote! {
154+
impl<T> #orig for T where T: #variant {
155+
#(#items)*
156+
}
157+
}
158+
}
159+
160+
fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
161+
// impl<T> IntFactory for T where T: SendIntFactory {
162+
// const NAME: &'static str = <Self as SendIntFactory>::NAME;
163+
// type MyFut<'a> = <Self as SendIntFactory>::MyFut<'a> where Self: 'a;
164+
// async fn make(&self, x: u32, y: &str) -> i32 {
165+
// <Self as SendIntFactory>::make(self, x, y).await
166+
// }
167+
// }
168+
match item {
169+
TraitItem::Const(TraitItemConst {
170+
ident,
171+
generics,
172+
ty,
173+
..
174+
}) => {
175+
quote! {
176+
const #ident #generics: #ty = <Self as #variant>::#ident;
177+
}
178+
}
179+
TraitItem::Fn(TraitItemFn { sig, .. }) => {
180+
let ident = &sig.ident;
181+
let args = sig.inputs.iter().map(|arg| match arg {
182+
FnArg::Receiver(_) => quote! { self },
183+
FnArg::Typed(PatType { pat, .. }) => match &**pat {
184+
Pat::Ident(arg) => quote! { #arg },
185+
_ => Error::new_spanned(pat, "patterns are not supported in arguments")
186+
.to_compile_error(),
187+
},
188+
});
189+
let maybe_await = if sig.asyncness.is_some() {
190+
quote! { .await }
191+
} else {
192+
quote! {}
193+
};
194+
quote! {
195+
#sig {
196+
<Self as #variant>::#ident(#(#args),*)#maybe_await
197+
}
198+
}
199+
}
200+
TraitItem::Type(TraitItemType {
201+
ident,
202+
generics:
203+
Generics {
204+
params,
205+
where_clause,
206+
..
207+
},
208+
..
209+
}) => {
210+
quote! {
211+
type #ident<#params> = <Self as #variant>::#ident<#params> #where_clause;
212+
}
213+
}
214+
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
215+
}
216+
}

0 commit comments

Comments
 (0)