@@ -667,6 +667,22 @@ def set_precedence(self, precedence, *nodes):
667
667
for node in nodes :
668
668
self ._precedences [node ] = precedence
669
669
670
+ def get_raw_docstring (self , node ):
671
+ """If a docstring node is found in the body of the *node* parameter,
672
+ return that docstring node, None otherwise.
673
+
674
+ Logic mirrored from ``_PyAST_GetDocString``."""
675
+ if not isinstance (
676
+ node , (AsyncFunctionDef , FunctionDef , ClassDef , Module )
677
+ ) or len (node .body ) < 1 :
678
+ return None
679
+ node = node .body [0 ]
680
+ if not isinstance (node , Expr ):
681
+ return None
682
+ node = node .value
683
+ if isinstance (node , Constant ) and isinstance (node .value , str ):
684
+ return node
685
+
670
686
def traverse (self , node ):
671
687
if isinstance (node , list ):
672
688
for item in node :
@@ -681,9 +697,15 @@ def visit(self, node):
681
697
self .traverse (node )
682
698
return "" .join (self ._source )
683
699
700
+ def _write_docstring_and_traverse_body (self , node ):
701
+ if (docstring := self .get_raw_docstring (node )):
702
+ self ._write_docstring (docstring )
703
+ self .traverse (node .body [1 :])
704
+ else :
705
+ self .traverse (node .body )
706
+
684
707
def visit_Module (self , node ):
685
- for subnode in node .body :
686
- self .traverse (subnode )
708
+ self ._write_docstring_and_traverse_body (node )
687
709
688
710
def visit_Expr (self , node ):
689
711
self .fill ()
@@ -850,15 +872,15 @@ def visit_ClassDef(self, node):
850
872
self .traverse (e )
851
873
852
874
with self .block ():
853
- self .traverse (node . body )
875
+ self ._write_docstring_and_traverse_body (node )
854
876
855
877
def visit_FunctionDef (self , node ):
856
- self .__FunctionDef_helper (node , "def" )
878
+ self ._function_helper (node , "def" )
857
879
858
880
def visit_AsyncFunctionDef (self , node ):
859
- self .__FunctionDef_helper (node , "async def" )
881
+ self ._function_helper (node , "async def" )
860
882
861
- def __FunctionDef_helper (self , node , fill_suffix ):
883
+ def _function_helper (self , node , fill_suffix ):
862
884
self .write ("\n " )
863
885
for deco in node .decorator_list :
864
886
self .fill ("@" )
@@ -871,15 +893,15 @@ def __FunctionDef_helper(self, node, fill_suffix):
871
893
self .write (" -> " )
872
894
self .traverse (node .returns )
873
895
with self .block ():
874
- self .traverse (node . body )
896
+ self ._write_docstring_and_traverse_body (node )
875
897
876
898
def visit_For (self , node ):
877
- self .__For_helper ("for " , node )
899
+ self ._for_helper ("for " , node )
878
900
879
901
def visit_AsyncFor (self , node ):
880
- self .__For_helper ("async for " , node )
902
+ self ._for_helper ("async for " , node )
881
903
882
- def __For_helper (self , fill , node ):
904
+ def _for_helper (self , fill , node ):
883
905
self .fill (fill )
884
906
self .traverse (node .target )
885
907
self .write (" in " )
@@ -974,6 +996,19 @@ def _fstring_FormattedValue(self, node, write):
974
996
def visit_Name (self , node ):
975
997
self .write (node .id )
976
998
999
+ def _write_docstring (self , node ):
1000
+ self .fill ()
1001
+ if node .kind == "u" :
1002
+ self .write ("u" )
1003
+
1004
+ # Preserve quotes in the docstring by escaping them
1005
+ value = node .value .replace ("\\ " , "\\ \\ " )
1006
+ value = value .replace ('"""' , '""\" ' )
1007
+ if value [- 1 ] == '"' :
1008
+ value = value .replace ('"' , '\\ "' , - 1 )
1009
+
1010
+ self .write (f'"""{ value } """' )
1011
+
977
1012
def _write_constant (self , value ):
978
1013
if isinstance (value , (float , complex )):
979
1014
# Substitute overflowing decimal literal for AST infinities.
0 commit comments