Skip to content

Commit 3445737

Browse files
fix: determine programming language if not set (#678)
1 parent bed8eff commit 3445737

File tree

6 files changed

+51
-15
lines changed

6 files changed

+51
-15
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import codegen
2+
from codegen.sdk.core.codebase import Codebase
3+
from codegen.shared.enums.programming_language import ProgrammingLanguage
4+
5+
6+
@codegen.function("test-language", subdirectories=["src/codegen/cli"], language=ProgrammingLanguage.PYTHON)
7+
def run(codebase: Codebase):
8+
file = codebase.get_file("src/codegen/cli/errors.py")
9+
print(f"File: {file.path}")
10+
for s in file.symbols:
11+
print(s.name)
12+
13+
14+
if __name__ == "__main__":
15+
print("Parsing codebase...")
16+
codebase = Codebase("./")
17+
18+
print("Running...")
19+
run(codebase)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ dependencies = [
6464
"langchain_core",
6565
"langchain_openai",
6666
"langgraph",
67+
"langgraph-prebuilt",
6768
"numpy>=2.2.2",
6869
"mcp[cli]",
6970
"neo4j",
@@ -239,6 +240,7 @@ DEP002 = [
239240
"pip",
240241
"python-levenshtein",
241242
"pytest-snapshot",
243+
"langgraph-prebuilt",
242244
]
243245
DEP003 = "sqlalchemy"
244246
DEP004 = "pytest"

src/codegen/cli/commands/list/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def list_command():
1818
table.add_column("Type", style="magenta")
1919
table.add_column("Path", style="dim")
2020
table.add_column("Subdirectories", style="dim")
21+
table.add_column("Language", style="dim")
2122

2223
for func in functions:
2324
func_type = "Webhook" if func.lint_mode else "Function"
@@ -26,6 +27,7 @@ def list_command():
2627
func_type,
2728
str(func.filepath.relative_to(Path.cwd())) if func.filepath else "<unknown>",
2829
", ".join(func.subdirectories) if func.subdirectories else "",
30+
func.language or "",
2931
)
3032

3133
rich.print(table)

src/codegen/cli/commands/run/run_local.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@
88
from codegen.cli.utils.function_finder import DecoratedFunction
99
from codegen.git.repo_operator.repo_operator import RepoOperator
1010
from codegen.git.schemas.repo_config import RepoConfig
11+
from codegen.git.utils.language import determine_project_language
1112
from codegen.sdk.codebase.config import ProjectConfig
1213
from codegen.sdk.core.codebase import Codebase
14+
from codegen.shared.enums.programming_language import ProgrammingLanguage
1315

1416

1517
def parse_codebase(
1618
repo_path: Path,
1719
subdirectories: list[str] | None = None,
20+
language: ProgrammingLanguage | None = None,
1821
) -> Codebase:
1922
"""Parse the codebase at the given root.
2023
@@ -29,6 +32,7 @@ def parse_codebase(
2932
ProjectConfig(
3033
repo_operator=RepoOperator(repo_config=RepoConfig.from_repo_path(repo_path=repo_path)),
3134
subdirectories=subdirectories,
35+
programming_language=language or determine_project_language(repo_path),
3236
)
3337
]
3438
)
@@ -48,8 +52,8 @@ def run_local(
4852
diff_preview: Number of lines of diff to preview (None for all)
4953
"""
5054
# Parse codebase and run
51-
with Status(f"[bold]Parsing codebase at {session.repo_path} with subdirectories {function.subdirectories or 'ALL'} ...", spinner="dots") as status:
52-
codebase = parse_codebase(repo_path=session.repo_path, subdirectories=function.subdirectories)
55+
with Status(f"[bold]Parsing codebase at {session.repo_path} with subdirectories {function.subdirectories or 'ALL'} and language {function.language or 'AUTO'} ...", spinner="dots") as status:
56+
codebase = parse_codebase(repo_path=session.repo_path, subdirectories=function.subdirectories, language=function.language)
5357
status.update("[bold green]✓ Parsed codebase")
5458

5559
status.update("[bold]Running codemod...")

src/codegen/cli/sdk/decorator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import wraps
33
from typing import Literal, ParamSpec, TypeVar, get_type_hints
44

5+
from codegen.shared.enums.programming_language import ProgrammingLanguage
6+
57
P = ParamSpec("P")
68
T = TypeVar("T")
79
WebhookType = Literal["pr", "push", "issue", "release"]
@@ -16,12 +18,14 @@ def __init__(
1618
name: str,
1719
*,
1820
subdirectories: list[str] | None = None,
21+
language: ProgrammingLanguage | None = None,
1922
webhook_config: dict | None = None,
2023
lint_mode: bool = False,
2124
lint_user_whitelist: Sequence[str] | None = None,
2225
):
2326
self.name = name
2427
self.subdirectories = subdirectories
28+
self.language = language
2529
self.func: Callable | None = None
2630
self.params_type = None
2731
self.webhook_config = webhook_config
@@ -44,7 +48,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
4448
return wrapper
4549

4650

47-
def function(name: str, subdirectories: list[str] | None = None) -> DecoratedFunction:
51+
def function(name: str, subdirectories: list[str] | None = None, language: ProgrammingLanguage | None = None) -> DecoratedFunction:
4852
"""Decorator for codegen functions.
4953
5054
Args:
@@ -56,7 +60,7 @@ def run(codebase):
5660
pass
5761
5862
"""
59-
return DecoratedFunction(name=name, subdirectories=subdirectories)
63+
return DecoratedFunction(name=name, subdirectories=subdirectories, language=language)
6064

6165

6266
def webhook(

src/codegen/cli/utils/function_finder.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dataclasses import dataclass
66
from pathlib import Path
77

8+
from codegen.shared.enums.programming_language import ProgrammingLanguage
9+
810

911
@dataclass
1012
class DecoratedFunction:
@@ -15,6 +17,7 @@ class DecoratedFunction:
1517
lint_mode: bool
1618
lint_user_whitelist: list[str]
1719
subdirectories: list[str] | None = None
20+
language: ProgrammingLanguage | None = None
1821
filepath: Path | None = None
1922
parameters: list[tuple[str, str | None]] = dataclasses.field(default_factory=list)
2023
arguments_type_schema: dict | None = None
@@ -98,6 +101,14 @@ def get_subdirectories(self, node: ast.Call) -> list[str] | None:
98101
return ast.literal_eval(node.args[1])
99102
return None
100103

104+
def get_language(self, node: ast.Call) -> ProgrammingLanguage | None:
105+
keywords = {k.arg: k.value for k in node.keywords}
106+
if "language" in keywords:
107+
return ProgrammingLanguage(keywords["language"].attr)
108+
if len(node.args) > 2:
109+
return ast.literal_eval(node.args[2])
110+
return None
111+
101112
def get_function_body(self, node: ast.FunctionDef) -> str:
102113
"""Extract and unindent the function body."""
103114
# Get the start and end positions of the function body
@@ -202,10 +213,6 @@ def visit_FunctionDef(self, node):
202213
(isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Attribute) and self._has_codegen_root(decorator.func.value))
203214
)
204215
):
205-
# Get the function name from the decorator argument
206-
func_name = self.get_function_name(decorator)
207-
subdirectories = self.get_subdirectories(decorator)
208-
209216
# Get additional metadata for webhook
210217
lint_mode = decorator.func.attr == "webhook"
211218
lint_user_whitelist = []
@@ -214,17 +221,15 @@ def visit_FunctionDef(self, node):
214221
if keyword.arg == "users" and isinstance(keyword.value, ast.List):
215222
lint_user_whitelist = [ast.literal_eval(elt).lstrip("@") for elt in keyword.value.elts]
216223

217-
# Get just the function body, unindented
218-
body_source = self.get_function_body(node)
219-
parameters = self.get_function_parameters(node)
220224
self.functions.append(
221225
DecoratedFunction(
222-
name=func_name,
223-
subdirectories=subdirectories,
224-
source=body_source,
226+
name=self.get_function_name(decorator),
227+
subdirectories=self.get_subdirectories(decorator),
228+
language=self.get_language(decorator),
229+
source=self.get_function_body(node),
225230
lint_mode=lint_mode,
226231
lint_user_whitelist=lint_user_whitelist,
227-
parameters=parameters,
232+
parameters=self.get_function_parameters(node),
228233
)
229234
)
230235

0 commit comments

Comments
 (0)