Skip to content

Commit f398107

Browse files
committed
fix: add to sdk
1 parent 2ebdb7c commit f398107

File tree

3 files changed

+230
-27
lines changed

3 files changed

+230
-27
lines changed

src/codegen/extensions/events/app.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import json
2+
import os
3+
from typing import Literal, Optional
4+
from fastapi import APIRouter, FastAPI, Request
5+
import modal
6+
import logging
7+
from codegen.extensions.events.codegen_app import CodegenApp
8+
from codegen.extensions.events.modal.request_util import fastapi_request_adapter
9+
from codegen.git.clients.git_repo_client import GitRepoClient
10+
from codegen.git.schemas.repo_config import RepoConfig
11+
from fastapi_utils.cbv import cbv
12+
13+
14+
logging.basicConfig(level=logging.INFO, force=True)
15+
logger = logging.getLogger(__name__)
16+
17+
# refactor this to be a config
18+
DEFAULT_SNAPSHOT_DICT_ID = "codegen-events-codebase-snapshots"
19+
20+
21+
class EventRouterMixin:
22+
"""
23+
This class is intended to be registered as a modal Class
24+
and will be used to route events to the correct handler.
25+
26+
Usage:
27+
@codegen_events_app.cls(image=base_image, secrets=[modal.Secret.from_dotenv()])
28+
class CustomEventAPI(EventRouterMixin):
29+
pass
30+
31+
"""
32+
33+
snapshot_index_id: str = DEFAULT_SNAPSHOT_DICT_ID
34+
35+
36+
37+
def get_event_handler_cls(self) -> modal.Cls:
38+
""" lookup the Modal Class where the event handlers are defined"""
39+
raise NotImplementedError("Subclasses must implement this method")
40+
41+
async def handle_event(self, org: str, repo: str, provider: Literal["slack", "github", "linear"], request: Request):
42+
43+
repo_config = RepoConfig(
44+
name=repo,
45+
full_name=f"{org}/{repo}",
46+
)
47+
48+
repo_snapshotdict = modal.Dict.from_name(self.snapshot_index_id, {}, create_if_missing=True)
49+
50+
last_snapshot_commit = repo_snapshotdict.get(f"{org}/{repo}", None)
51+
52+
if last_snapshot_commit is None:
53+
54+
git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"])
55+
branch = git_client.get_branch_safe(git_client.default_branch)
56+
last_snapshot_commit = branch.commit.sha if branch and branch.commit else None
57+
58+
Klass = self.get_event_handler_cls()
59+
klass = Klass(repo_org=org, repo_name=repo, commit=last_snapshot_commit)
60+
61+
print(f"Repo info: org: {org} repo: {repo} commit: {last_snapshot_commit}")
62+
print("DEBUG: ", await request.body())
63+
request_payload = await request.json()
64+
request_headers = dict(request.headers)
65+
request_headers.pop('host', None) # Remove host header if present
66+
67+
68+
if provider == "slack":
69+
return klass.proxy_event.remote(f"{org}/{repo}/slack/events", payload=request_payload, headers=request_headers)
70+
elif provider == "github":
71+
return klass.proxy_event.remote(f"{org}/{repo}/github/events", payload=request_payload, headers=request_headers)
72+
elif provider == "linear":
73+
return klass.proxy_event.remote(f"{org}/{repo}/linear/events", payload=request_payload, headers=request_headers)
74+
else:
75+
raise ValueError(f"Invalid provider: {provider}")
76+
77+
78+
79+
def refresh_repository_snapshots(self, snapshot_index_id: str):
80+
"""Refresh the latest snapshot for all repositories in the dictionary."""
81+
# Get all repositories from the modal.Dict
82+
repo_dict = modal.Dict.from_name(snapshot_index_id, {}, create_if_missing=True)
83+
84+
for repo_full_name in repo_dict.keys():
85+
try:
86+
# Parse the repository full name to get org and repo
87+
org, repo = repo_full_name.split('/')
88+
89+
# Create a RepoConfig for the repository
90+
repo_config = RepoConfig(
91+
name=repo,
92+
full_name=repo_full_name,
93+
)
94+
95+
# Initialize the GitRepoClient to fetch the latest commit
96+
git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"])
97+
98+
# Get the default branch and its latest commit
99+
branch = git_client.get_branch_safe(git_client.default_branch)
100+
commit = branch.commit.sha if branch and branch.commit else None
101+
102+
if commit:
103+
# Get the CodegenEventsApi class
104+
Klass = self.get_event_handler_cls()
105+
# Create an instance with the latest commit
106+
klass = Klass(repo_org=org, repo_name=repo, commit=commit)
107+
108+
# Ping the function to refresh the snapshot
109+
result = klass.ping.remote()
110+
111+
logging.info(f"Refreshed snapshot for {repo_full_name} with commit {commit}: {result}")
112+
else:
113+
logging.warning(f"Could not fetch latest commit for {repo_full_name}")
114+
115+
except Exception as e:
116+
logging.error(f"Error refreshing snapshot for {repo_full_name}: {str(e)}")
117+
118+
119+
120+
121+
class CodebaseEventsApp:
122+
"""
123+
This class is intended to be registered as a modal Class
124+
and will be used to register event handlers for webhook events. It includes snapshotting behavior
125+
and should be used with CodebaseEventsAPI.
126+
127+
Usage:
128+
@app.cls(image=base_image, secrets=[modal.Secret.from_dotenv()], enable_memory_snapshot=True, container_idle_timeout=300)
129+
class YourCustomerEventsAPP(CodebaseEventsApp):
130+
pass
131+
"""
132+
133+
commit: str = modal.parameter(default="")
134+
repo_org: str = modal.parameter(default="")
135+
repo_name: str = modal.parameter(default="")
136+
snapshot_index_id: str = DEFAULT_SNAPSHOT_DICT_ID
137+
138+
139+
def get_codegen_app(self) -> CodegenApp:
140+
full_repo_name = f"{self.repo_org}/{self.repo_name}"
141+
return CodegenApp(name=f"{full_repo_name}-events", repo=full_repo_name, commit=self.commit)
142+
143+
@modal.enter(snap=True)
144+
def load(self):
145+
self.cg = self.get_codegen_app()
146+
self.cg.parse_repo()
147+
self.setup_handlers(self.cg)
148+
149+
150+
# TODO: if multiple snapshots are taken for the same commit, we will need to compare commit timestamps
151+
snapshot_dict = modal.Dict.from_name(self.snapshot_index_id, {}, create_if_missing=True)
152+
snapshot_dict.put(f"{self.repo_org}/{self.repo_name}", self.commit)
153+
154+
def setup_handlers(self, cg: CodegenApp):
155+
raise NotImplementedError("Subclasses must implement this method")
156+
157+
@modal.method()
158+
async def proxy_event(self, route: str, payload: dict, headers: dict):
159+
logger.info(f"Handling event: {route}")
160+
request = await fastapi_request_adapter(payload=payload, headers=headers, route=route)
161+
162+
if "slack/events" in route:
163+
response_data = await self.cg.handle_slack_event(request)
164+
elif "github/events" in route:
165+
response_data = await self.cg.handle_github_event(request)
166+
elif "linear/events" in route:
167+
response_data = await self.cg.handle_linear_event(request)
168+
else:
169+
raise ValueError(f"Invalid route: {route}")
170+
171+
return response_data
172+
173+
@modal.method()
174+
def ping(self):
175+
logger.info(f"Pinging function with repo: {self.repo_org}/{self.repo_name} commit: {self.commit}")
176+
return {"status": "ok"}
177+
178+
@modal.asgi_app()
179+
def fastapi_endpoint(self):
180+
logger.info("Serving FastAPI app from class method")
181+
return self.cg.app
182+
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
from fastapi import Request as FastAPIRequest
3+
from requests import Request
4+
5+
6+
7+
async def fastapi_request_adapter(payload: dict, headers: dict, route: str) -> FastAPIRequest:
8+
9+
# Create a FastAPI Request object from the payload and headers
10+
# 1. Create the scope dictionary
11+
scope = {
12+
"type": "http",
13+
"method": "POST",
14+
"path": f"/{route}",
15+
"raw_path": f"/{route}".encode(),
16+
"query_string": b"",
17+
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
18+
"client": ("127.0.0.1", 0), # Default client address
19+
}
20+
21+
# 2. Create a receive function that returns the request body
22+
body_bytes = json.dumps(payload).encode()
23+
24+
async def receive():
25+
return {
26+
"type": "http.request",
27+
"body": body_bytes,
28+
"more_body": False,
29+
}
30+
31+
# 3. Create a send function to capture the response
32+
response_body = []
33+
response_status = None
34+
response_headers = None
35+
36+
async def send(message):
37+
nonlocal response_status, response_headers
38+
39+
if message["type"] == "http.response.start":
40+
response_status = message["status"]
41+
response_headers = message["headers"]
42+
elif message["type"] == "http.response.body":
43+
response_body.append(message.get("body", b""))
44+
45+
# 4. Create the request object
46+
fastapi_request = FastAPIRequest(scope, receive)
47+
48+
return fastapi_request

0 commit comments

Comments
 (0)