Skip to content

Commit e695b31

Browse files
authored
feat: add CLI (#45)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed - [ ] I have read and agree to the [Contributor License Agreement](../CLA.md)
1 parent fcaccfc commit e695b31

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+3806
-78
lines changed

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ dependencies = [
5353
"pyinstrument>=5.0.0",
5454
"pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work!
5555
"emoji>=2.14.0",
56+
"rich-click>=1.8.5",
57+
"python-dotenv>=1.0.1",
58+
"giturlparse",
59+
"pygit2>=1.16.0",
60+
"unidiff>=0.7.5",
61+
"datamodel-code-generator>=0.26.5",
62+
"toml>=0.10.2",
5663
"PyGithub==2.5.0",
5764
"GitPython==3.1.44",
5865
]
@@ -62,8 +69,10 @@ classifiers = [
6269
"Intended Audience :: Developers", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Software Development", "Development Status :: 4 - Beta", "Environment :: MacOS X", "Programming Language :: Python :: 3", "Programming Language :: Python", ]
6370

6471
[project.scripts]
72+
codegen = "codegen.cli.cli:main"
6573
gs = "codegen.gscli.main:main"
6674
run_string = "codegen.sdk.core.main:main"
75+
6776
[project.optional-dependencies]
6877
types = [
6978
"types-networkx>=3.2.1.20240918",

src/codegen/cli/__init__.py

Whitespace-only changes.

src/codegen/cli/_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ENV = ""

src/codegen/cli/api/client.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import json
2+
from typing import ClassVar, TypeVar
3+
4+
import requests
5+
from pydantic import BaseModel
6+
from rich import print as rprint
7+
8+
from codegen.cli.api.endpoints import (
9+
CREATE_ENDPOINT,
10+
DEPLOY_ENDPOINT,
11+
DOCS_ENDPOINT,
12+
EXPERT_ENDPOINT,
13+
IDENTIFY_ENDPOINT,
14+
LOOKUP_ENDPOINT,
15+
PR_LOOKUP_ENDPOINT,
16+
RUN_ENDPOINT,
17+
RUN_ON_PR_ENDPOINT,
18+
)
19+
from codegen.cli.api.schemas import (
20+
AskExpertInput,
21+
AskExpertResponse,
22+
CodemodRunType,
23+
CreateInput,
24+
CreateResponse,
25+
DeployInput,
26+
DeployResponse,
27+
DocsInput,
28+
DocsResponse,
29+
IdentifyResponse,
30+
LookupInput,
31+
LookupOutput,
32+
PRLookupInput,
33+
PRLookupResponse,
34+
PRSchema,
35+
RunCodemodInput,
36+
RunCodemodOutput,
37+
RunOnPRInput,
38+
RunOnPRResponse,
39+
)
40+
from codegen.cli.auth.session import CodegenSession
41+
from codegen.cli.codemod.convert import convert_to_ui
42+
from codegen.cli.env.global_env import global_env
43+
from codegen.cli.errors import InvalidTokenError, ServerError
44+
from codegen.cli.utils.codemods import Codemod
45+
from codegen.cli.utils.function_finder import DecoratedFunction
46+
47+
InputT = TypeVar("InputT", bound=BaseModel)
48+
OutputT = TypeVar("OutputT", bound=BaseModel)
49+
50+
51+
class RestAPI:
52+
"""Handles auth + validation with the codegen API."""
53+
54+
_session: ClassVar[requests.Session] = requests.Session()
55+
56+
auth_token: str | None = None
57+
58+
def __init__(self, auth_token: str):
59+
self.auth_token = auth_token
60+
61+
def _get_headers(self) -> dict[str, str]:
62+
"""Get headers with authentication token."""
63+
return {"Authorization": f"Bearer {self.auth_token}"}
64+
65+
def _make_request(
66+
self,
67+
method: str,
68+
endpoint: str,
69+
input_data: InputT | None,
70+
output_model: type[OutputT],
71+
) -> OutputT:
72+
"""Make an API request with input validation and response handling."""
73+
if global_env.DEBUG:
74+
rprint(f"[purple]{method}[/purple] {endpoint}")
75+
if input_data:
76+
rprint(f"{json.dumps(input_data.model_dump(), indent=4)}")
77+
78+
try:
79+
headers = self._get_headers()
80+
81+
json_data = input_data.model_dump() if input_data else None
82+
83+
response = self._session.request(
84+
method,
85+
endpoint,
86+
json=json_data,
87+
headers=headers,
88+
)
89+
90+
if response.status_code == 200:
91+
try:
92+
return output_model.model_validate(response.json())
93+
except ValueError as e:
94+
raise ServerError(f"Invalid response format: {e}")
95+
elif response.status_code == 401:
96+
raise InvalidTokenError("Invalid or expired authentication token")
97+
elif response.status_code == 500:
98+
raise ServerError("The server encountered an error while processing your request")
99+
else:
100+
try:
101+
error_json = response.json()
102+
error_msg = error_json.get("detail", error_json)
103+
except Exception:
104+
error_msg = response.text
105+
raise ServerError(f"Error ({response.status_code}): {error_msg}")
106+
107+
except requests.RequestException as e:
108+
raise ServerError(f"Network error: {e!s}")
109+
110+
def run(
111+
self,
112+
function: DecoratedFunction | Codemod,
113+
include_source: bool = True,
114+
run_type: CodemodRunType = CodemodRunType.DIFF,
115+
template_context: dict[str, str] | None = None,
116+
) -> RunCodemodOutput:
117+
"""Run a codemod transformation.
118+
119+
Args:
120+
function: The function or codemod to run
121+
include_source: Whether to include the source code in the request.
122+
If False, uses the deployed version.
123+
run_type: Type of run (diff or pr)
124+
template_context: Context variables to pass to the codemod
125+
126+
"""
127+
session = CodegenSession()
128+
129+
base_input = {
130+
"codemod_name": function.name,
131+
"repo_full_name": session.repo_name,
132+
"codemod_run_type": run_type,
133+
}
134+
135+
# Only include source if requested
136+
if include_source:
137+
source = function.get_current_source() if isinstance(function, Codemod) else function.source
138+
base_input["codemod_source"] = convert_to_ui(source)
139+
140+
# Add template context if provided
141+
if template_context:
142+
base_input["template_context"] = template_context
143+
144+
input_data = RunCodemodInput(input=RunCodemodInput.BaseRunCodemodInput(**base_input))
145+
return self._make_request(
146+
"POST",
147+
RUN_ENDPOINT,
148+
input_data,
149+
RunCodemodOutput,
150+
)
151+
152+
def get_docs(self) -> dict:
153+
"""Search documentation."""
154+
session = CodegenSession()
155+
return self._make_request(
156+
"GET",
157+
DOCS_ENDPOINT,
158+
DocsInput(docs_input=DocsInput.BaseDocsInput(repo_full_name=session.repo_name)),
159+
DocsResponse,
160+
)
161+
162+
def ask_expert(self, query: str) -> AskExpertResponse:
163+
"""Ask the expert system a question."""
164+
return self._make_request(
165+
"GET",
166+
EXPERT_ENDPOINT,
167+
AskExpertInput(input=AskExpertInput.BaseAskExpertInput(query=query)),
168+
AskExpertResponse,
169+
)
170+
171+
def create(self, name: str, query: str) -> CreateResponse:
172+
"""Get AI-generated starter code for a codemod."""
173+
session = CodegenSession()
174+
return self._make_request(
175+
"GET",
176+
CREATE_ENDPOINT,
177+
CreateInput(input=CreateInput.BaseCreateInput(name=name, query=query, repo_full_name=session.repo_name)),
178+
CreateResponse,
179+
)
180+
181+
def identify(self) -> IdentifyResponse | None:
182+
"""Identify the user's codemod."""
183+
return self._make_request(
184+
"POST",
185+
IDENTIFY_ENDPOINT,
186+
None,
187+
IdentifyResponse,
188+
)
189+
190+
def deploy(
191+
self, codemod_name: str, codemod_source: str, lint_mode: bool = False, lint_user_whitelist: list[str] | None = None, message: str | None = None, arguments_schema: dict | None = None
192+
) -> DeployResponse:
193+
"""Deploy a codemod to the Modal backend."""
194+
session = CodegenSession()
195+
return self._make_request(
196+
"POST",
197+
DEPLOY_ENDPOINT,
198+
DeployInput(
199+
input=DeployInput.BaseDeployInput(
200+
codemod_name=codemod_name,
201+
codemod_source=codemod_source,
202+
repo_full_name=session.repo_name,
203+
lint_mode=lint_mode,
204+
lint_user_whitelist=lint_user_whitelist or [],
205+
message=message,
206+
arguments_schema=arguments_schema,
207+
)
208+
),
209+
DeployResponse,
210+
)
211+
212+
def lookup(self, codemod_name: str) -> LookupOutput:
213+
"""Look up a codemod by name."""
214+
session = CodegenSession()
215+
return self._make_request(
216+
"GET",
217+
LOOKUP_ENDPOINT,
218+
LookupInput(input=LookupInput.BaseLookupInput(codemod_name=codemod_name, repo_full_name=session.repo_name)),
219+
LookupOutput,
220+
)
221+
222+
def run_on_pr(self, codemod_name: str, repo_full_name: str, github_pr_number: int, language: str | None = None) -> RunOnPRResponse:
223+
"""Test a webhook against a specific PR."""
224+
return self._make_request(
225+
"POST",
226+
RUN_ON_PR_ENDPOINT,
227+
RunOnPRInput(
228+
input=RunOnPRInput.BaseRunOnPRInput(
229+
codemod_name=codemod_name,
230+
repo_full_name=repo_full_name,
231+
github_pr_number=github_pr_number,
232+
language=language,
233+
)
234+
),
235+
RunOnPRResponse,
236+
)
237+
238+
def lookup_pr(self, repo_full_name: str, github_pr_number: int) -> PRSchema:
239+
"""Look up a PR by repository and PR number."""
240+
return self._make_request(
241+
"GET",
242+
PR_LOOKUP_ENDPOINT,
243+
PRLookupInput(input=PRLookupInput.BasePRLookupInput(repo_full_name=repo_full_name, github_pr_number=github_pr_number)),
244+
PRLookupResponse,
245+
)

src/codegen/cli/api/endpoints.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from codegen.cli.api.modal import MODAL_PREFIX
2+
3+
RUN_ENDPOINT = f"https://{MODAL_PREFIX}--cli-run.modal.run"
4+
DOCS_ENDPOINT = f"https://{MODAL_PREFIX}--cli-docs.modal.run"
5+
EXPERT_ENDPOINT = f"https://{MODAL_PREFIX}--cli-ask-expert.modal.run"
6+
IDENTIFY_ENDPOINT = f"https://{MODAL_PREFIX}--cli-identify.modal.run"
7+
CREATE_ENDPOINT = f"https://{MODAL_PREFIX}--cli-create.modal.run"
8+
DEPLOY_ENDPOINT = f"https://{MODAL_PREFIX}--cli-deploy.modal.run"
9+
LOOKUP_ENDPOINT = f"https://{MODAL_PREFIX}--cli-lookup.modal.run"
10+
RUN_ON_PR_ENDPOINT = f"https://{MODAL_PREFIX}--cli-run-on-pull-request.modal.run"
11+
PR_LOOKUP_ENDPOINT = f"https://{MODAL_PREFIX}--cli-pr-lookup.modal.run"

src/codegen/cli/api/modal.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from codegen.cli.env.enums import Environment
2+
from codegen.cli.env.global_env import global_env
3+
4+
5+
def get_modal_workspace():
6+
match global_env.ENV:
7+
case Environment.PRODUCTION:
8+
return "codegen-sh"
9+
case Environment.STAGING:
10+
return "codegen-sh-staging"
11+
case Environment.DEVELOP:
12+
return "codegen-sh-develop"
13+
case _:
14+
raise ValueError(f"Invalid environment: {global_env.ENV}")
15+
16+
17+
def get_modal_prefix():
18+
workspace = get_modal_workspace()
19+
if global_env.ENV == Environment.DEVELOP and global_env.MODAL_ENVIRONMENT:
20+
return f"{workspace}-{global_env.MODAL_ENVIRONMENT}"
21+
return workspace
22+
23+
24+
MODAL_PREFIX = get_modal_prefix()

0 commit comments

Comments
 (0)