1
+ import json
2
+ import os
3
+ from typing import Literal , Optional
4
+ from fastapi import APIRouter , FastAPI , Request
5
+ import modal
6
+ import logging
7
+ from codegen .extensions .events .codegen_app import CodegenApp
8
+ from codegen .extensions .events .modal .request_util import fastapi_request_adapter
9
+ from codegen .git .clients .git_repo_client import GitRepoClient
10
+ from codegen .git .schemas .repo_config import RepoConfig
11
+ from fastapi_utils .cbv import cbv
12
+
13
+
14
+ logging .basicConfig (level = logging .INFO , force = True )
15
+ logger = logging .getLogger (__name__ )
16
+
17
+ # refactor this to be a config
18
+ DEFAULT_SNAPSHOT_DICT_ID = "codegen-events-codebase-snapshots"
19
+
20
+
21
+ class EventRouterMixin :
22
+ """
23
+ This class is intended to be registered as a modal Class
24
+ and will be used to route events to the correct handler.
25
+
26
+ Usage:
27
+ @codegen_events_app.cls(image=base_image, secrets=[modal.Secret.from_dotenv()])
28
+ class CustomEventAPI(EventRouterMixin):
29
+ pass
30
+
31
+ """
32
+
33
+ snapshot_index_id : str = DEFAULT_SNAPSHOT_DICT_ID
34
+
35
+
36
+
37
+ def get_event_handler_cls (self ) -> modal .Cls :
38
+ """ lookup the Modal Class where the event handlers are defined"""
39
+ raise NotImplementedError ("Subclasses must implement this method" )
40
+
41
+ async def handle_event (self , org : str , repo : str , provider : Literal ["slack" , "github" , "linear" ], request : Request ):
42
+
43
+ repo_config = RepoConfig (
44
+ name = repo ,
45
+ full_name = f"{ org } /{ repo } " ,
46
+ )
47
+
48
+ repo_snapshotdict = modal .Dict .from_name (self .snapshot_index_id , {}, create_if_missing = True )
49
+
50
+ last_snapshot_commit = repo_snapshotdict .get (f"{ org } /{ repo } " , None )
51
+
52
+ if last_snapshot_commit is None :
53
+
54
+ git_client = GitRepoClient (repo_config = repo_config , access_token = os .environ ["GITHUB_ACCESS_TOKEN" ])
55
+ branch = git_client .get_branch_safe (git_client .default_branch )
56
+ last_snapshot_commit = branch .commit .sha if branch and branch .commit else None
57
+
58
+ Klass = self .get_event_handler_cls ()
59
+ klass = Klass (repo_org = org , repo_name = repo , commit = last_snapshot_commit )
60
+
61
+ print (f"Repo info: org: { org } repo: { repo } commit: { last_snapshot_commit } " )
62
+ print ("DEBUG: " , await request .body ())
63
+ request_payload = await request .json ()
64
+ request_headers = dict (request .headers )
65
+ request_headers .pop ('host' , None ) # Remove host header if present
66
+
67
+
68
+ if provider == "slack" :
69
+ return klass .proxy_event .remote (f"{ org } /{ repo } /slack/events" , payload = request_payload , headers = request_headers )
70
+ elif provider == "github" :
71
+ return klass .proxy_event .remote (f"{ org } /{ repo } /github/events" , payload = request_payload , headers = request_headers )
72
+ elif provider == "linear" :
73
+ return klass .proxy_event .remote (f"{ org } /{ repo } /linear/events" , payload = request_payload , headers = request_headers )
74
+ else :
75
+ raise ValueError (f"Invalid provider: { provider } " )
76
+
77
+
78
+
79
+ def refresh_repository_snapshots (self , snapshot_index_id : str ):
80
+ """Refresh the latest snapshot for all repositories in the dictionary."""
81
+ # Get all repositories from the modal.Dict
82
+ repo_dict = modal .Dict .from_name (snapshot_index_id , {}, create_if_missing = True )
83
+
84
+ for repo_full_name in repo_dict .keys ():
85
+ try :
86
+ # Parse the repository full name to get org and repo
87
+ org , repo = repo_full_name .split ('/' )
88
+
89
+ # Create a RepoConfig for the repository
90
+ repo_config = RepoConfig (
91
+ name = repo ,
92
+ full_name = repo_full_name ,
93
+ )
94
+
95
+ # Initialize the GitRepoClient to fetch the latest commit
96
+ git_client = GitRepoClient (repo_config = repo_config , access_token = os .environ ["GITHUB_ACCESS_TOKEN" ])
97
+
98
+ # Get the default branch and its latest commit
99
+ branch = git_client .get_branch_safe (git_client .default_branch )
100
+ commit = branch .commit .sha if branch and branch .commit else None
101
+
102
+ if commit :
103
+ # Get the CodegenEventsApi class
104
+ Klass = self .get_event_handler_cls ()
105
+ # Create an instance with the latest commit
106
+ klass = Klass (repo_org = org , repo_name = repo , commit = commit )
107
+
108
+ # Ping the function to refresh the snapshot
109
+ result = klass .ping .remote ()
110
+
111
+ logging .info (f"Refreshed snapshot for { repo_full_name } with commit { commit } : { result } " )
112
+ else :
113
+ logging .warning (f"Could not fetch latest commit for { repo_full_name } " )
114
+
115
+ except Exception as e :
116
+ logging .error (f"Error refreshing snapshot for { repo_full_name } : { str (e )} " )
117
+
118
+
119
+
120
+
121
+ class CodebaseEventsApp :
122
+ """
123
+ This class is intended to be registered as a modal Class
124
+ and will be used to register event handlers for webhook events. It includes snapshotting behavior
125
+ and should be used with CodebaseEventsAPI.
126
+
127
+ Usage:
128
+ @app.cls(image=base_image, secrets=[modal.Secret.from_dotenv()], enable_memory_snapshot=True, container_idle_timeout=300)
129
+ class YourCustomerEventsAPP(CodebaseEventsApp):
130
+ pass
131
+ """
132
+
133
+ commit : str = modal .parameter (default = "" )
134
+ repo_org : str = modal .parameter (default = "" )
135
+ repo_name : str = modal .parameter (default = "" )
136
+ snapshot_index_id : str = DEFAULT_SNAPSHOT_DICT_ID
137
+
138
+
139
+ def get_codegen_app (self ) -> CodegenApp :
140
+ full_repo_name = f"{ self .repo_org } /{ self .repo_name } "
141
+ return CodegenApp (name = f"{ full_repo_name } -events" , repo = full_repo_name , commit = self .commit )
142
+
143
+ @modal .enter (snap = True )
144
+ def load (self ):
145
+ self .cg = self .get_codegen_app ()
146
+ self .cg .parse_repo ()
147
+ self .setup_handlers (self .cg )
148
+
149
+
150
+ # TODO: if multiple snapshots are taken for the same commit, we will need to compare commit timestamps
151
+ snapshot_dict = modal .Dict .from_name (self .snapshot_index_id , {}, create_if_missing = True )
152
+ snapshot_dict .put (f"{ self .repo_org } /{ self .repo_name } " , self .commit )
153
+
154
+ def setup_handlers (self , cg : CodegenApp ):
155
+ raise NotImplementedError ("Subclasses must implement this method" )
156
+
157
+ @modal .method ()
158
+ async def proxy_event (self , route : str , payload : dict , headers : dict ):
159
+ logger .info (f"Handling event: { route } " )
160
+ request = await fastapi_request_adapter (payload = payload , headers = headers , route = route )
161
+
162
+ if "slack/events" in route :
163
+ response_data = await self .cg .handle_slack_event (request )
164
+ elif "github/events" in route :
165
+ response_data = await self .cg .handle_github_event (request )
166
+ elif "linear/events" in route :
167
+ response_data = await self .cg .handle_linear_event (request )
168
+ else :
169
+ raise ValueError (f"Invalid route: { route } " )
170
+
171
+ return response_data
172
+
173
+ @modal .method ()
174
+ def ping (self ):
175
+ logger .info (f"Pinging function with repo: { self .repo_org } /{ self .repo_name } commit: { self .commit } " )
176
+ return {"status" : "ok" }
177
+
178
+ @modal .asgi_app ()
179
+ def fastapi_endpoint (self ):
180
+ logger .info ("Serving FastAPI app from class method" )
181
+ return self .cg .app
182
+
0 commit comments