1
1
import ast
2
- from dataclasses import dataclass
2
+ from dataclasses import field , dataclass
3
3
import re
4
4
from typing import Any , Dict , IO , Optional , List , Text , Tuple , Set
5
5
from enum import Enum
25
25
)
26
26
from pegen .parser_generator import ParserGenerator
27
27
28
+
28
29
EXTENSION_PREFIX = """\
29
30
#include "pegen.h"
30
31
@@ -63,7 +64,7 @@ class NodeTypes(Enum):
63
64
@dataclass
64
65
class FunctionCall :
65
66
function : str
66
- arguments : Optional [ List [Any ]] = None
67
+ arguments : List [Any ] = field ( default_factory = list )
67
68
assigned_variable : Optional [str ] = None
68
69
return_type : Optional [str ] = None
69
70
nodetype : Optional [NodeTypes ] = None
@@ -94,7 +95,7 @@ def __init__(
94
95
self .gen = parser_generator
95
96
self .exact_tokens = exact_tokens
96
97
self .non_exact_tokens = non_exact_tokens
97
- self .cache : Dict [Any , Any ] = {}
98
+ self .cache : Dict [Any , FunctionCall ] = {}
98
99
self .keyword_cache : Dict [str , int ] = {}
99
100
100
101
def keyword_helper (self , keyword : str ) -> FunctionCall :
@@ -171,7 +172,7 @@ def can_we_inline(node: Rhs) -> int:
171
172
if node in self .cache :
172
173
return self .cache [node ]
173
174
if can_we_inline (node ):
174
- self .cache [node ] = self .visit (node .alts [0 ].items [0 ])
175
+ self .cache [node ] = self .generate_call (node .alts [0 ].items [0 ])
175
176
else :
176
177
name = self .gen .name_node (node )
177
178
self .cache [node ] = FunctionCall (
@@ -183,13 +184,13 @@ def can_we_inline(node: Rhs) -> int:
183
184
return self .cache [node ]
184
185
185
186
def visit_NamedItem (self , node : NamedItem ) -> FunctionCall :
186
- call = self .visit (node .item )
187
+ call = self .generate_call (node .item )
187
188
if node .name :
188
189
call .assigned_variable = node .name
189
190
return call
190
191
191
192
def lookahead_call_helper (self , node : Lookahead , positive : int ) -> FunctionCall :
192
- call = self .visit (node .node )
193
+ call = self .generate_call (node .node )
193
194
if call .nodetype == NodeTypes .NAME_TOKEN :
194
195
return FunctionCall (
195
196
function = f"_PyPegen_lookahead_with_name" ,
@@ -217,7 +218,7 @@ def visit_NegativeLookahead(self, node: NegativeLookahead) -> FunctionCall:
217
218
return self .lookahead_call_helper (node , 0 )
218
219
219
220
def visit_Opt (self , node : Opt ) -> FunctionCall :
220
- call = self .visit (node .node )
221
+ call = self .generate_call (node .node )
221
222
return FunctionCall (
222
223
assigned_variable = "_opt_var" ,
223
224
function = call .function ,
@@ -266,7 +267,7 @@ def visit_Gather(self, node: Gather) -> FunctionCall:
266
267
return self .cache [node ]
267
268
268
269
def visit_Group (self , node : Group ) -> FunctionCall :
269
- return self .visit (node .rhs )
270
+ return self .generate_call (node .rhs )
270
271
271
272
def visit_Cut (self , node : Cut ) -> FunctionCall :
272
273
return FunctionCall (
@@ -276,6 +277,9 @@ def visit_Cut(self, node: Cut) -> FunctionCall:
276
277
nodetype = NodeTypes .CUT_OPERATOR ,
277
278
)
278
279
280
+ def generate_call (self , node : Any ) -> FunctionCall :
281
+ return super ().visit (node )
282
+
279
283
280
284
class CParserGenerator (ParserGenerator , GrammarVisitor ):
281
285
def __init__ (
@@ -317,17 +321,13 @@ def call_with_errorcheck_goto(self, call_text: str, goto_target: str) -> None:
317
321
self .print (f"goto { goto_target } ;" )
318
322
self .print (f"}}" )
319
323
320
- def out_of_memory_return (
321
- self ,
322
- expr : str ,
323
- cleanup_code : Optional [str ] = None ,
324
- ) -> None :
324
+ def out_of_memory_return (self , expr : str , cleanup_code : Optional [str ] = None ,) -> None :
325
325
self .print (f"if ({ expr } ) {{" )
326
326
with self .indent ():
327
327
if cleanup_code is not None :
328
328
self .print (cleanup_code )
329
329
self .print ("p->error_indicator = 1;" )
330
- self .print ("PyErr_NoMemory();" );
330
+ self .print ("PyErr_NoMemory();" )
331
331
self .print ("return NULL;" )
332
332
self .print (f"}}" )
333
333
@@ -484,10 +484,7 @@ def _handle_default_rule_body(self, node: Rule, rhs: Rhs, result_type: str) -> N
484
484
if any (alt .action and "EXTRA" in alt .action for alt in rhs .alts ):
485
485
self ._set_up_token_start_metadata_extraction ()
486
486
self .visit (
487
- rhs ,
488
- is_loop = False ,
489
- is_gather = node .is_gather (),
490
- rulename = node .name ,
487
+ rhs , is_loop = False , is_gather = node .is_gather (), rulename = node .name ,
491
488
)
492
489
if self .debug :
493
490
self .print ('fprintf(stderr, "Fail at %d: {node.name}\\ n", p->mark);' )
@@ -518,10 +515,7 @@ def _handle_loop_rule_body(self, node: Rule, rhs: Rhs) -> None:
518
515
if any (alt .action and "EXTRA" in alt .action for alt in rhs .alts ):
519
516
self ._set_up_token_start_metadata_extraction ()
520
517
self .visit (
521
- rhs ,
522
- is_loop = True ,
523
- is_gather = node .is_gather (),
524
- rulename = node .name ,
518
+ rhs , is_loop = True , is_gather = node .is_gather (), rulename = node .name ,
525
519
)
526
520
if is_repeat1 :
527
521
self .print ("if (_n == 0 || p->error_indicator) {" )
@@ -567,7 +561,7 @@ def visit_Rule(self, node: Rule) -> None:
567
561
self .print ("}" )
568
562
569
563
def visit_NamedItem (self , node : NamedItem ) -> None :
570
- call = self .callmakervisitor .visit (node )
564
+ call = self .callmakervisitor .generate_call (node )
571
565
if call .assigned_variable :
572
566
call .assigned_variable = self .dedupe (call .assigned_variable )
573
567
self .print (call )
@@ -674,7 +668,9 @@ def handle_alt_loop(self, node: Alt, is_gather: bool, rulename: Optional[str]) -
674
668
self .print ("if (_n == _children_capacity) {" )
675
669
with self .indent ():
676
670
self .print ("_children_capacity *= 2;" )
677
- self .print ("void **_new_children = PyMem_Realloc(_children, _children_capacity*sizeof(void *));" )
671
+ self .print (
672
+ "void **_new_children = PyMem_Realloc(_children, _children_capacity*sizeof(void *));"
673
+ )
678
674
self .out_of_memory_return (f"!_new_children" )
679
675
self .print ("_children = _new_children;" )
680
676
self .print ("}" )
@@ -721,5 +717,8 @@ def collect_vars(self, node: Alt) -> Dict[Optional[str], Optional[str]]:
721
717
return types
722
718
723
719
def add_var (self , node : NamedItem ) -> Tuple [Optional [str ], Optional [str ]]:
724
- call = self .callmakervisitor .visit (node .item )
725
- return self .dedupe (node .name if node .name else call .assigned_variable ), call .return_type
720
+ call = self .callmakervisitor .generate_call (node .item )
721
+ name = node .name if node .name else call .assigned_variable
722
+ if name is not None :
723
+ name = self .dedupe (name )
724
+ return name , call .return_type
0 commit comments