Skip to content

Commit 89aa469

Browse files
bpo-38870: Add docstring support to ast.unparse (GH-17760)
Allow ast.unparse to detect docstrings in functions, modules and classes and produce nicely formatted unparsed output for said docstrings. Co-Authored-By: Pablo Galindo <[email protected]>
1 parent 66b7973 commit 89aa469

File tree

2 files changed

+171
-80
lines changed

2 files changed

+171
-80
lines changed

Lib/ast.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,22 @@ def set_precedence(self, precedence, *nodes):
667667
for node in nodes:
668668
self._precedences[node] = precedence
669669

670+
def get_raw_docstring(self, node):
671+
"""If a docstring node is found in the body of the *node* parameter,
672+
return that docstring node, None otherwise.
673+
674+
Logic mirrored from ``_PyAST_GetDocString``."""
675+
if not isinstance(
676+
node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
677+
) or len(node.body) < 1:
678+
return None
679+
node = node.body[0]
680+
if not isinstance(node, Expr):
681+
return None
682+
node = node.value
683+
if isinstance(node, Constant) and isinstance(node.value, str):
684+
return node
685+
670686
def traverse(self, node):
671687
if isinstance(node, list):
672688
for item in node:
@@ -681,9 +697,15 @@ def visit(self, node):
681697
self.traverse(node)
682698
return "".join(self._source)
683699

700+
def _write_docstring_and_traverse_body(self, node):
701+
if (docstring := self.get_raw_docstring(node)):
702+
self._write_docstring(docstring)
703+
self.traverse(node.body[1:])
704+
else:
705+
self.traverse(node.body)
706+
684707
def visit_Module(self, node):
685-
for subnode in node.body:
686-
self.traverse(subnode)
708+
self._write_docstring_and_traverse_body(node)
687709

688710
def visit_Expr(self, node):
689711
self.fill()
@@ -850,15 +872,15 @@ def visit_ClassDef(self, node):
850872
self.traverse(e)
851873

852874
with self.block():
853-
self.traverse(node.body)
875+
self._write_docstring_and_traverse_body(node)
854876

855877
def visit_FunctionDef(self, node):
856-
self.__FunctionDef_helper(node, "def")
878+
self._function_helper(node, "def")
857879

858880
def visit_AsyncFunctionDef(self, node):
859-
self.__FunctionDef_helper(node, "async def")
881+
self._function_helper(node, "async def")
860882

861-
def __FunctionDef_helper(self, node, fill_suffix):
883+
def _function_helper(self, node, fill_suffix):
862884
self.write("\n")
863885
for deco in node.decorator_list:
864886
self.fill("@")
@@ -871,15 +893,15 @@ def __FunctionDef_helper(self, node, fill_suffix):
871893
self.write(" -> ")
872894
self.traverse(node.returns)
873895
with self.block():
874-
self.traverse(node.body)
896+
self._write_docstring_and_traverse_body(node)
875897

876898
def visit_For(self, node):
877-
self.__For_helper("for ", node)
899+
self._for_helper("for ", node)
878900

879901
def visit_AsyncFor(self, node):
880-
self.__For_helper("async for ", node)
902+
self._for_helper("async for ", node)
881903

882-
def __For_helper(self, fill, node):
904+
def _for_helper(self, fill, node):
883905
self.fill(fill)
884906
self.traverse(node.target)
885907
self.write(" in ")
@@ -974,6 +996,19 @@ def _fstring_FormattedValue(self, node, write):
974996
def visit_Name(self, node):
975997
self.write(node.id)
976998

999+
def _write_docstring(self, node):
1000+
self.fill()
1001+
if node.kind == "u":
1002+
self.write("u")
1003+
1004+
# Preserve quotes in the docstring by escaping them
1005+
value = node.value.replace("\\", "\\\\")
1006+
value = value.replace('"""', '""\"')
1007+
if value[-1] == '"':
1008+
value = value.replace('"', '\\"', -1)
1009+
1010+
self.write(f'"""{value}"""')
1011+
9771012
def _write_constant(self, value):
9781013
if isinstance(value, (float, complex)):
9791014
# Substitute overflowing decimal literal for AST infinities.

0 commit comments

Comments
 (0)