Skip to content

Commit e2c1da3

Browse files
Support macros in pattern position
1 parent bd675c8 commit e2c1da3

File tree

7 files changed

+88
-10
lines changed

7 files changed

+88
-10
lines changed

crates/hir_def/src/body/lower.rs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,9 @@ impl ExprCollector<'_> {
531531
}
532532
}
533533
ast::Expr::MacroCall(e) => {
534+
let macro_ptr = AstPtr::new(&e);
534535
let mut ids = vec![];
535-
self.collect_macro_call(e, syntax_ptr.clone(), true, |this, expansion| {
536+
self.collect_macro_call(e, macro_ptr, true, |this, expansion| {
536537
ids.push(match expansion {
537538
Some(it) => this.collect_expr(it),
538539
None => this.alloc_expr(Expr::Missing, syntax_ptr.clone()),
@@ -555,7 +556,7 @@ impl ExprCollector<'_> {
555556
fn collect_macro_call<F: FnMut(&mut Self, Option<T>), T: ast::AstNode>(
556557
&mut self,
557558
e: ast::MacroCall,
558-
syntax_ptr: AstPtr<ast::Expr>,
559+
syntax_ptr: AstPtr<ast::MacroCall>,
559560
is_error_recoverable: bool,
560561
mut collector: F,
561562
) {
@@ -643,10 +644,14 @@ impl ExprCollector<'_> {
643644

644645
// Note that macro could be expended to multiple statements
645646
if let Some(ast::Expr::MacroCall(m)) = stmt.expr() {
647+
let macro_ptr = AstPtr::new(&m);
646648
let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());
647649

648-
self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| {
649-
match expansion {
650+
self.collect_macro_call(
651+
m,
652+
macro_ptr,
653+
false,
654+
|this, expansion| match expansion {
650655
Some(expansion) => {
651656
let statements: ast::MacroStmts = expansion;
652657

@@ -660,8 +665,8 @@ impl ExprCollector<'_> {
660665
let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone());
661666
this.statements_in_scope.push(Statement::Expr(expr));
662667
}
663-
}
664-
});
668+
},
669+
);
665670
} else {
666671
let expr = self.collect_expr_opt(stmt.expr());
667672
self.statements_in_scope.push(Statement::Expr(expr));
@@ -848,8 +853,23 @@ impl ExprCollector<'_> {
848853
Pat::Missing
849854
}
850855
}
856+
ast::Pat::MacroPat(mac) => match mac.macro_call() {
857+
Some(call) => {
858+
let macro_ptr = AstPtr::new(&call);
859+
let mut pat = None;
860+
self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| {
861+
pat = Some(this.collect_pat_opt(expanded_pat));
862+
});
863+
864+
match pat {
865+
Some(pat) => return pat,
866+
None => Pat::Missing,
867+
}
868+
}
869+
None => Pat::Missing,
870+
},
851871
// FIXME: implement
852-
ast::Pat::RangePat(_) | ast::Pat::MacroPat(_) => Pat::Missing,
872+
ast::Pat::RangePat(_) => Pat::Missing,
853873
};
854874
let ptr = AstPtr::new(&pat);
855875
self.alloc_pat(pattern, Either::Left(ptr))

crates/hir_def/src/item_tree.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ impl ItemTree {
9999
// items.
100100
ctx.lower_macro_stmts(stmts)
101101
},
102+
ast::Pat(_pat) => {
103+
// FIXME: This occurs because macros in pattern position are treated as inner
104+
// items and expanded during block DefMap computation
105+
return Default::default();
106+
},
102107
ast::Expr(e) => {
103108
// Macros can expand to expressions. We return an empty item tree in this case, but
104109
// still need to collect inner items.

crates/hir_def/src/item_tree/lower.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ impl Ctx {
189189
block_stack.push(self.source_ast_id_map.ast_id(&block));
190190
},
191191
ast::Item(item) => {
192-
// FIXME: This triggers for macro calls in expression position
192+
// FIXME: This triggers for macro calls in expression/pattern/type position
193193
let mod_items = self.lower_mod_item(&item, true);
194194
let current_block = block_stack.last();
195195
if let (Some(mod_items), Some(block)) = (mod_items, current_block) {

crates/hir_expand/src/db.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind {
439439
match parent.kind() {
440440
MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items,
441441
MACRO_STMTS => FragmentKind::Statements,
442+
MACRO_PAT => FragmentKind::Pattern,
442443
ITEM_LIST => FragmentKind::Items,
443444
LET_STMT => {
444445
// FIXME: Handle LHS Pattern

crates/hir_ty/src/tests/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,11 +1065,11 @@ fn macro_in_arm() {
10651065
}
10661066
"#,
10671067
expect![[r#"
1068+
!0..2 '()': ()
10681069
51..110 '{ ... }; }': ()
10691070
61..62 'x': u32
10701071
65..107 'match ... }': u32
10711072
71..73 '()': ()
1072-
84..91 'unit!()': ()
10731073
95..100 '92u32': u32
10741074
"#]],
10751075
);

crates/hir_ty/src/tests/patterns.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use expect_test::expect;
22

3-
use super::{check_infer, check_infer_with_mismatches};
3+
use super::{check_infer, check_infer_with_mismatches, check_types};
44

55
#[test]
66
fn infer_pattern() {
@@ -825,3 +825,29 @@ fn foo(foo: Foo) {
825825
"#]],
826826
);
827827
}
828+
829+
#[test]
830+
fn macro_pat() {
831+
check_types(
832+
r#"
833+
macro_rules! pat {
834+
($name:ident) => { Enum::Variant1($name) }
835+
}
836+
837+
enum Enum {
838+
Variant1(u8),
839+
Variant2,
840+
}
841+
842+
fn f(e: Enum) {
843+
match e {
844+
pat!(bind) => {
845+
bind;
846+
//^^^^ u8
847+
}
848+
Enum::Variant2 => {}
849+
}
850+
}
851+
"#,
852+
)
853+
}

crates/ide/src/goto_definition.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,32 @@ pub mod theitem {
11851185
pub fn gimme() -> theitem::TheItem {
11861186
theitem::TheItem
11871187
}
1188+
"#,
1189+
);
1190+
}
1191+
1192+
#[test]
1193+
fn goto_ident_from_pat_macro() {
1194+
check(
1195+
r#"
1196+
macro_rules! pat {
1197+
($name:ident) => { Enum::Variant1($name) }
1198+
}
1199+
1200+
enum Enum {
1201+
Variant1(u8),
1202+
Variant2,
1203+
}
1204+
1205+
fn f(e: Enum) {
1206+
match e {
1207+
pat!(bind) => {
1208+
//^^^^
1209+
bind$0
1210+
}
1211+
Enum::Variant2 => {}
1212+
}
1213+
}
11881214
"#,
11891215
);
11901216
}

0 commit comments

Comments
 (0)