Skip to content

Commit 689d8a2

Browse files
feat(cli): codegen.function can specify subdirectories (#674)
1 parent 4ca5044 commit 689d8a2

File tree

7 files changed

+60
-22
lines changed

7 files changed

+60
-22
lines changed

.codegen/.gitignore

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,8 @@ jupyter/
88
codegen-system-prompt.txt
99

1010
# Python cache files
11-
__pycache__/
11+
**/__pycache__/
1212
*.py[cod]
1313
*$py.class
1414
*.txt
1515
*.pyc
16-
17-
# Keep codemods
18-
!codemods/
19-
!codemods/**

.codegen/codemods/no_link_backticks/no_link_backticks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from codegen import Codebase
33

44

5-
@codegen.function("no-link-backticks")
5+
@codegen.function(name="no-link-backticks", subdirectories=["test/unit"])
66
def run(codebase: Codebase):
77
import re
88

@@ -12,6 +12,7 @@ def run(codebase: Codebase):
1212
# Iterate over all .mdx files in the codebase
1313
for file in codebase.files(extensions=["mdx"]):
1414
if file.extension == ".mdx":
15+
print(f"Processing {file.path}")
1516
new_content = file.content
1617

1718
# Find all markdown links with backticks in link text

src/codegen/cli/commands/list/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@ def list_command():
1717
table.add_column("Name", style="cyan")
1818
table.add_column("Type", style="magenta")
1919
table.add_column("Path", style="dim")
20+
table.add_column("Subdirectories", style="dim")
2021

2122
for func in functions:
2223
func_type = "Webhook" if func.lint_mode else "Function"
23-
table.add_row(func.name, func_type, str(func.filepath.relative_to(Path.cwd())) if func.filepath else "<unknown>")
24+
table.add_row(
25+
func.name,
26+
func_type,
27+
str(func.filepath.relative_to(Path.cwd())) if func.filepath else "<unknown>",
28+
", ".join(func.subdirectories) if func.subdirectories else "",
29+
)
2430

2531
rich.print(table)
2632
rich.print("\nRun a function with:")

src/codegen/cli/commands/run/main.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@
1313
@click.command(name="run")
1414
@requires_init
1515
@click.argument("label", required=True)
16-
@click.option("--path", type=str, help="Path to build the codebase from. Defaults to the repo root.")
1716
@click.option("--web", is_flag=True, help="Run the function on the web service instead of locally")
1817
@click.option("--diff-preview", type=int, help="Show a preview of the first N lines of the diff")
1918
@click.option("--arguments", type=str, help="Arguments as a json string to pass as the function's 'arguments' parameter")
2019
def run_command(
2120
session: CodegenSession,
2221
label: str,
23-
path: str | None = None,
2422
web: bool = False,
2523
diff_preview: int | None = None,
2624
arguments: str | None = None,
@@ -59,4 +57,4 @@ def run_command(
5957
else:
6058
from codegen.cli.commands.run.run_local import run_local
6159

62-
run_local(session, codemod, diff_preview=diff_preview, path=path)
60+
run_local(session, codemod, diff_preview=diff_preview)

src/codegen/cli/commands/run/run_local.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
from codegen.cli.auth.session import CodegenSession
88
from codegen.cli.utils.function_finder import DecoratedFunction
9+
from codegen.git.repo_operator.repo_operator import RepoOperator
10+
from codegen.git.schemas.repo_config import RepoConfig
11+
from codegen.sdk.codebase.config import ProjectConfig
912
from codegen.sdk.core.codebase import Codebase
1013

1114

12-
def parse_codebase(repo_root: Path) -> Codebase:
15+
def parse_codebase(
16+
repo_path: Path,
17+
subdirectories: list[str] | None = None,
18+
) -> Codebase:
1319
"""Parse the codebase at the given root.
1420
1521
Args:
@@ -18,15 +24,21 @@ def parse_codebase(repo_root: Path) -> Codebase:
1824
Returns:
1925
Parsed Codebase object
2026
"""
21-
codebase = Codebase(repo_root)
27+
codebase = Codebase(
28+
projects=[
29+
ProjectConfig(
30+
repo_operator=RepoOperator(repo_config=RepoConfig.from_repo_path(repo_path=repo_path)),
31+
subdirectories=subdirectories,
32+
)
33+
]
34+
)
2235
return codebase
2336

2437

2538
def run_local(
2639
session: CodegenSession,
2740
function: DecoratedFunction,
2841
diff_preview: int | None = None,
29-
path: Path | None = None,
3042
) -> None:
3143
"""Run a function locally against the codebase.
3244
@@ -36,10 +48,8 @@ def run_local(
3648
diff_preview: Number of lines of diff to preview (None for all)
3749
"""
3850
# Parse codebase and run
39-
codebase_path = f"{session.repo_path}/{path}" if path else session.repo_path
40-
41-
with Status(f"[bold]Parsing codebase at {codebase_path} ...", spinner="dots") as status:
42-
codebase = parse_codebase(codebase_path)
51+
with Status(f"[bold]Parsing codebase at {session.repo_path} with subdirectories {function.subdirectories or 'ALL'} ...", spinner="dots") as status:
52+
codebase = parse_codebase(repo_path=session.repo_path, subdirectories=function.subdirectories)
4353
status.update("[bold green]✓ Parsed codebase")
4454

4555
status.update("[bold]Running codemod...")

src/codegen/cli/sdk/decorator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ def __init__(
1515
self,
1616
name: str,
1717
*,
18+
subdirectories: list[str] | None = None,
1819
webhook_config: dict | None = None,
1920
lint_mode: bool = False,
2021
lint_user_whitelist: Sequence[str] | None = None,
2122
):
2223
self.name = name
24+
self.subdirectories = subdirectories
2325
self.func: Callable | None = None
2426
self.params_type = None
2527
self.webhook_config = webhook_config
@@ -42,7 +44,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
4244
return wrapper
4345

4446

45-
def function(name: str) -> DecoratedFunction:
47+
def function(name: str, subdirectories: list[str] | None = None) -> DecoratedFunction:
4648
"""Decorator for codegen functions.
4749
4850
Args:
@@ -54,7 +56,7 @@ def run(codebase):
5456
pass
5557
5658
"""
57-
return DecoratedFunction(name)
59+
return DecoratedFunction(name=name, subdirectories=subdirectories)
5860

5961

6062
def webhook(

src/codegen/cli/utils/function_finder.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class DecoratedFunction:
1414
source: str
1515
lint_mode: bool
1616
lint_user_whitelist: list[str]
17+
subdirectories: list[str] | None = None
1718
filepath: Path | None = None
1819
parameters: list[tuple[str, str | None]] = dataclasses.field(default_factory=list)
1920
arguments_type_schema: dict | None = None
@@ -83,6 +84,20 @@ class CodegenFunctionVisitor(ast.NodeVisitor):
8384
def __init__(self):
8485
self.functions: list[DecoratedFunction] = []
8586

87+
def get_function_name(self, node: ast.Call) -> str:
88+
keywords = {k.arg: k.value for k in node.keywords}
89+
if "name" in keywords:
90+
return ast.literal_eval(keywords["name"])
91+
return ast.literal_eval(node.args[0])
92+
93+
def get_subdirectories(self, node: ast.Call) -> list[str] | None:
94+
keywords = {k.arg: k.value for k in node.keywords}
95+
if "subdirectories" in keywords:
96+
return ast.literal_eval(keywords["subdirectories"])
97+
if len(node.args) > 1:
98+
return ast.literal_eval(node.args[1])
99+
return None
100+
86101
def get_function_body(self, node: ast.FunctionDef) -> str:
87102
"""Extract and unindent the function body."""
88103
# Get the start and end positions of the function body
@@ -178,7 +193,7 @@ def visit_FunctionDef(self, node):
178193
for decorator in node.decorator_list:
179194
if (
180195
isinstance(decorator, ast.Call)
181-
and len(decorator.args) >= 1
196+
and (len(decorator.args) > 0 or len(decorator.keywords) > 0)
182197
and (
183198
# Check if it's a direct codegen.X call
184199
(isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and decorator.func.value.id == "codegen")
@@ -188,7 +203,8 @@ def visit_FunctionDef(self, node):
188203
)
189204
):
190205
# Get the function name from the decorator argument
191-
func_name = ast.literal_eval(decorator.args[0])
206+
func_name = self.get_function_name(decorator)
207+
subdirectories = self.get_subdirectories(decorator)
192208

193209
# Get additional metadata for webhook
194210
lint_mode = decorator.func.attr == "webhook"
@@ -201,7 +217,16 @@ def visit_FunctionDef(self, node):
201217
# Get just the function body, unindented
202218
body_source = self.get_function_body(node)
203219
parameters = self.get_function_parameters(node)
204-
self.functions.append(DecoratedFunction(name=func_name, source=body_source, lint_mode=lint_mode, lint_user_whitelist=lint_user_whitelist, parameters=parameters))
220+
self.functions.append(
221+
DecoratedFunction(
222+
name=func_name,
223+
subdirectories=subdirectories,
224+
source=body_source,
225+
lint_mode=lint_mode,
226+
lint_user_whitelist=lint_user_whitelist,
227+
parameters=parameters,
228+
)
229+
)
205230

206231
def _has_codegen_root(self, node):
207232
"""Recursively check if an AST node chain starts with codegen."""

0 commit comments

Comments
 (0)